[Mlir-commits] [mlir] Add OneTypedResult::getResultOfType to simplify the result type casting logic (PR #120381)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 18 18:17:48 PST 2024


https://github.com/xiaoleis-nv updated https://github.com/llvm/llvm-project/pull/120381

>From 1012bf39696678453d48e48cd26ad424809a7d88 Mon Sep 17 00:00:00 2001
From: Xiaolei Shi <xiaoleis at nvidia.com>
Date: Wed, 18 Dec 2024 00:14:48 -0800
Subject: [PATCH 1/3] add getResultOfType

---
 mlir/include/mlir/IR/OpDefinition.h           |  9 +++++++--
 .../Transforms/RuntimeOpVerification.cpp      |  2 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 20 +++++++++++--------
 3 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d6690991..ae28e1251bd954 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -694,11 +694,16 @@ class OneTypedResult {
   class Impl
       : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
   public:
-    mlir::TypedValue<ResultType> getResult() {
-      return cast<mlir::TypedValue<ResultType>>(
+    template <typename ValTy>
+    mlir::TypedValue<ValTy> getResultOfType() {
+      return mlir::cast<mlir::TypedValue<ValTy>>(
           this->getOperation()->getResult(0));
     }
 
+    mlir::TypedValue<ResultType> getResult() {
+      return getResultOfType<ResultType>();
+    }
+
     /// If the operation returns a single value, then the Op can be implicitly
     /// converted to a Value. This yields the value of the only result.
     operator mlir::TypedValue<ResultType>() { return getResult(); }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7f..6d5a68ef4d0add 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -213,7 +213,7 @@ struct ReinterpretCastOpInterface
     auto reinterpretCast = cast<ReinterpretCastOp>(op);
     auto baseMemref = reinterpretCast.getSource();
     auto resultMemref =
-        cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
+        reinterpretCast.getResultOfType<BaseMemRefType>();
 
     builder.setInsertionPointAfter(op);
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 327ea0991e4e1e..5c268c06db08e6 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -85,13 +85,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
   }
 
   builder.setInsertionPointAfterValue(sourceShard);
-  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
+  TypedValue<ShapedType> resultValue =
       builder
           .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                                sourceSharding.getMeshAttr().getLeafReference(),
                                allReduceMeshAxes, sourceShard,
                                sourceSharding.getPartialType())
-          .getResult());
+          .getResultOfType<ShapedType>();
 
   llvm::SmallVector<MeshAxis> remainingPartialAxes;
   llvm::copy_if(sourceShardingPartialAxesSet,
@@ -133,12 +133,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshSharding sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+  TypedValue<ShapedType> targetShard =
       builder
           .create<AllSliceOp>(sourceShard, mesh,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
-          .getResult());
+          .getResultOfType<ShapedType>();
   MeshSharding targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
   return {targetShard, targetSharding};
@@ -274,8 +274,10 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
       APInt(64, splitTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
+  TypedValue<ShapedType> targetShard =
+      builder
+          .create<tensor::CastOp>(targetShape, allGatherResult)
+          .getResultOfType<ShapedType>();
   return {targetShard, targetSharding};
 }
 
@@ -407,8 +409,10 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
+  TypedValue<ShapedType> targetShard =
+      builder
+          .create<tensor::CastOp>(targetShape, allToAllResult)
+          .getResultOfType<ShapedType>();
   return {targetShard, targetSharding};
 }
 

>From c556755da0e41cb611b5d349ff83d121337d980e Mon Sep 17 00:00:00 2001
From: Xiaolei Shi <xiaoleis at nvidia.com>
Date: Wed, 18 Dec 2024 00:39:39 -0800
Subject: [PATCH 2/3] fix format issue

---
 .../lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 3 +--
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp            | 6 ++----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 6d5a68ef4d0add..5ca7108f79a920 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -212,8 +212,7 @@ struct ReinterpretCastOpInterface
                                    Location loc) const {
     auto reinterpretCast = cast<ReinterpretCastOp>(op);
     auto baseMemref = reinterpretCast.getSource();
-    auto resultMemref =
-        reinterpretCast.getResultOfType<BaseMemRefType>();
+    auto resultMemref = reinterpretCast.getResultOfType<BaseMemRefType>();
 
     builder.setInsertionPointAfter(op);
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 5c268c06db08e6..6c41ca8edc0936 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -275,8 +275,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
   TypedValue<ShapedType> targetShard =
-      builder
-          .create<tensor::CastOp>(targetShape, allGatherResult)
+      builder.create<tensor::CastOp>(targetShape, allGatherResult)
           .getResultOfType<ShapedType>();
   return {targetShard, targetSharding};
 }
@@ -410,8 +409,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
   TypedValue<ShapedType> targetShard =
-      builder
-          .create<tensor::CastOp>(targetShape, allToAllResult)
+      builder.create<tensor::CastOp>(targetShape, allToAllResult)
           .getResultOfType<ShapedType>();
   return {targetShard, targetSharding};
 }

>From 55bc56b8cf537c9c549ee5e03c6c2bc429545d49 Mon Sep 17 00:00:00 2001
From: Xiaolei Shi <xiaoleis at nvidia.com>
Date: Wed, 18 Dec 2024 18:17:21 -0800
Subject: [PATCH 3/3] rename getResultOfType to getResultAs

---
 mlir/include/mlir/IR/OpDefinition.h                       | 4 ++--
 .../Dialect/MemRef/Transforms/RuntimeOpVerification.cpp   | 2 +-
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp          | 8 ++++----
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index ae28e1251bd954..827274f09b4b14 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -695,13 +695,13 @@ class OneTypedResult {
       : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
   public:
     template <typename ValTy>
-    mlir::TypedValue<ValTy> getResultOfType() {
+    mlir::TypedValue<ValTy> getResultAs() {
       return mlir::cast<mlir::TypedValue<ValTy>>(
           this->getOperation()->getResult(0));
     }
 
     mlir::TypedValue<ResultType> getResult() {
-      return getResultOfType<ResultType>();
+      return getResultAs<ResultType>();
     }
 
     /// If the operation returns a single value, then the Op can be implicitly
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 5ca7108f79a920..1a852ed05096ad 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -212,7 +212,7 @@ struct ReinterpretCastOpInterface
                                    Location loc) const {
     auto reinterpretCast = cast<ReinterpretCastOp>(op);
     auto baseMemref = reinterpretCast.getSource();
-    auto resultMemref = reinterpretCast.getResultOfType<BaseMemRefType>();
+    auto resultMemref = reinterpretCast.getResultAs<BaseMemRefType>();
 
     builder.setInsertionPointAfter(op);
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6c41ca8edc0936..2f1003766dabd5 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -91,7 +91,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
                                sourceSharding.getMeshAttr().getLeafReference(),
                                allReduceMeshAxes, sourceShard,
                                sourceSharding.getPartialType())
-          .getResultOfType<ShapedType>();
+          .getResultAs<ShapedType>();
 
   llvm::SmallVector<MeshAxis> remainingPartialAxes;
   llvm::copy_if(sourceShardingPartialAxesSet,
@@ -138,7 +138,7 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
           .create<AllSliceOp>(sourceShard, mesh,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
-          .getResultOfType<ShapedType>();
+          .getResultAs<ShapedType>();
   MeshSharding targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
   return {targetShard, targetSharding};
@@ -276,7 +276,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
   TypedValue<ShapedType> targetShard =
       builder.create<tensor::CastOp>(targetShape, allGatherResult)
-          .getResultOfType<ShapedType>();
+          .getResultAs<ShapedType>();
   return {targetShard, targetSharding};
 }
 
@@ -410,7 +410,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
   TypedValue<ShapedType> targetShard =
       builder.create<tensor::CastOp>(targetShape, allToAllResult)
-          .getResultOfType<ShapedType>();
+          .getResultAs<ShapedType>();
   return {targetShard, targetSharding};
 }
 



More information about the Mlir-commits mailing list