[Mlir-commits] [mlir] [mlir][mesh] fixes for 0d tensors (PR #132948)

Frank Schlimbach llvmlistbot at llvm.org
Wed Mar 26 10:14:55 PDT 2025


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/132948

From f54a37cb31285aa94fae0c60e0c7560841d1708d Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 25 Mar 2025 16:32:39 +0100
Subject: [PATCH 1/2] fixes for 0d tensors

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  2 ++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  2 +-
 .../Mesh/Interfaces/ShardingInterface.cpp     |  2 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 29 ++++++++++++++-----
 .../Extensions/MeshShardingExtensions.cpp     | 18 ++++++++----
 .../test/Dialect/Tensor/mesh-spmdization.mlir |  7 +++++
 6 files changed, 44 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index fc5cfffea27a7..32c2eca2cefa8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -119,6 +119,8 @@ inline bool isFullReplication(MeshSharding sharding) {
 inline mesh::MeshOp
 getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
               SymbolTableCollection &symbolTableCollection) {
+  if (!meshSymbol)
+    return nullptr;
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
       op, meshSymbol);
 }
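
For context: the new guard makes getMeshOrNull() safe to call with a
sharding that carries no mesh symbol, which is exactly what the default
MeshSharding() produced for non-tensor and 0d results further down. A
minimal caller sketch (hypothetical helper, not part of the patch;
namespace qualification follows MeshOps.h):

  #include "mlir/Dialect/Mesh/IR/MeshOps.h"

  using namespace mlir;

  // Shard `type` only when the sharding actually names a mesh.
  static Type shardIfMeshKnown(Operation *op, Type type,
                               mesh::MeshSharding sharding,
                               SymbolTableCollection &symbolTable) {
    // A null FlatSymbolRefAttr no longer reaches lookupNearestSymbolFrom();
    // we simply get a null MeshOp back.
    mesh::MeshOp meshOp =
        mesh::getMeshOrNull(op, sharding.getMeshAttr(), symbolTable);
    if (!meshOp)
      return type; // No mesh symbol: leave the type unsharded.
    return mesh::shardType(type, meshOp, sharding);
  }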
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3e9f86fde64f3..65475b69dbdb1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -269,7 +269,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
 
 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
-  if (rankedTensorType) {
+  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
     return shardShapedType(rankedTensorType, mesh, sharding);
   }
   return type;
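
The empty-shape check makes shardType() the identity on rank-0 tensors:
there is no dimension to split across mesh axes. Illustrative snippet
(sketch only, assuming an MLIRContext `ctx` is in scope):

  // A 0d tensor type has rank 0 and an empty shape.
  RankedTensorType t0 =
      RankedTensorType::get(/*shape=*/{}, Float32Type::get(&ctx));
  assert(t0.getRank() == 0 && t0.getShape().empty());
  // After this patch the mesh and sharding are never consulted for such
  // types; a default-constructed (null) MeshOp is fine here.
  mesh::MeshOp noMesh;
  Type sharded = mesh::shardType(t0, noMesh, mesh::MeshSharding());
  assert(sharded == t0);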
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index f427d004c558f..8aaa0704119a8 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -718,6 +718,6 @@ void mesh::spmdizeTriviallyShardableOperation(
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
     newResult.setType(
         shardType(newResult.getType(),
-                  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
+                  getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
   }
 }
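
This pairs with the shardType() change above: getMeshOrNull() may now hand
shardType() a null mesh, but that can only happen for shardings without a
mesh symbol, i.e. exactly the cases where shardType() never reads the mesh.
A defensive restatement of that invariant (sketch, not in the patch):

  static Type shardTypeChecked(Type type, mesh::MeshOp meshOrNull,
                               mesh::MeshSharding sharding) {
    auto rtt = dyn_cast<RankedTensorType>(type);
    if (rtt && !rtt.getShape().empty()) {
      assert(meshOrNull && "sharded tensor result without a mesh symbol");
      return mesh::shardShapedType(rtt, meshOrNull, sharding);
    }
    return type; // meshOrNull may be null here; it is never read.
  }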
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 601af0200e785..b6cb06ae3170f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
       block.getArguments(), std::back_inserter(res),
       [&symbolTableCollection](BlockArgument arg) {
         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-        if (!rankedTensorArg) {
+        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
           return arg.getType();
         }
 
@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
         dyn_cast<TypedValue<RankedTensorType>>(operand);
-    if (!rankedTensor) {
+    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
       return MeshSharding();
     }
 
@@ -690,18 +690,31 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
   llvm::transform(op.getResults(), std::back_inserter(res),
-                  [](OpResult result) {
+                  [&op](OpResult result) {
+                    if (!result.hasOneUse() || result.use_empty()) {
+                      return MeshSharding();
+                    }
                     TypedValue<RankedTensorType> rankedTensor =
                         dyn_cast<TypedValue<RankedTensorType>>(result);
                     if (!rankedTensor) {
                       return MeshSharding();
                     }
-                    if (!result.hasOneUse()) {
-                      return MeshSharding();
-                    }
                     Operation *userOp = *result.getUsers().begin();
-                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    return MeshSharding(shardOp.getSharding());
+                    ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+                    if (shardOp) {
+                      return MeshSharding(shardOp.getSharding());
+                    }
+                    if (rankedTensor.getType().getRank() == 0) {
+                      // This is a 0d tensor result without explicit sharding.
+                      // Find mesh symbol from operands, if any.
+                      // Shardings without a mesh are not always fully supported yet.
+                      for (auto operand: op.getOperands()) {
+                        if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+                          return MeshSharding(sharding.getMeshAttr());
+                        }
+                      }
+                    }
+                    return MeshSharding();
                   });
   return res;
 }
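
The reworked logic deserves a short summary: a result consumed by a
mesh.shard op keeps that op's sharding, as before; a 0d result with no such
user now inherits just the mesh symbol from the first operand defined by a
mesh.sharding op, giving a mesh-annotated sharding with no split axes
(i.e. replication). The fallback, extracted as a standalone sketch
(hypothetical free function, mirroring the lambda body):

  static mesh::MeshSharding inheritMeshFromOperands(Operation &op) {
    for (Value operand : op.getOperands())
      if (auto shardingOp = operand.getDefiningOp<mesh::ShardingOp>())
        // Mesh symbol only, no split axes: effectively full replication.
        return mesh::MeshSharding(shardingOp.getMeshAttr());
    return mesh::MeshSharding(); // No mesh found; see the caveat above.
  }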
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index b3d69eb5e1a23..fc93f1c1c9220 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -50,19 +50,25 @@ struct CreatorOpShardingInterface
                         IRMapping &spmdizationMap,
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
-    auto mesh =
-        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
-    auto shardType = cast<ShapedType>(
-        mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
+    assert(resultShardings.size() == 1);
+    auto resType = cast<RankedTensorType>(op->getResult(0).getType());
+    mlir::mesh::MeshOp mesh;
+    ShapedType shardType;
+    if (resType.getRank() > 0) {
+      mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+      shardType =
+          cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+    } else {
+      shardType = resType;
+    }
     Operation *newOp = nullptr;
     // if the sharding introduces a new dynamic dimension, we take it from
     // the dynamic sharding info. For now bail out if it's not
     // provided.
-    assert(resultShardings.size() == 1);
     if (!shardType.hasStaticShape()) {
       assert(op->getResult(0).hasOneUse());
       SmallVector<Value> newOperands;
-      auto oldType = cast<ShapedType>(op->getResult(0).getType());
+      auto oldType = cast<ShapedType>(resType);
       assert(oldType.getRank() == shardType.getRank());
       int currOldOprndNum = -1;
       mesh::ShardShapeOp shapeForDevice;
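
For creator ops such as tensor.empty the same idea applies during
spmdization: a rank-0 result skips the mesh lookup entirely and keeps its
type, and the size-1 assert now runs before resultShardings[0] is first
read. Condensed control flow (sketch; builder and operand handling from
the full method omitted):

  assert(resultShardings.size() == 1 && "creator ops have a single result");
  auto resType = cast<RankedTensorType>(op->getResult(0).getType());
  ShapedType shardType = resType; // rank 0: the type survives unchanged
  if (resType.getRank() > 0) {
    auto meshOp =
        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
    shardType = cast<ShapedType>(
        mesh::shardType(resType, meshOp, resultShardings[0]));
  }
  // E.g. tensor.empty() : tensor<f32> is rebuilt verbatim; see the new test.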
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 01cf5972177f4..3fb8424745501 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -43,3 +43,10 @@ func.func @tensor_empty_same_static_dims_sizes() -> () {
 
   return
 }
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+  tensor.empty() : tensor<f32>
+  // CHECK-NEXT:  tensor.empty() : tensor<f32>
+  return
+}

From 30fbf18a57f9853c7f4fec371504af53c68bc764 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 25 Mar 2025 19:35:01 +0100
Subject: [PATCH 2/2] clang-format

---
 .../Mesh/Interfaces/ShardingInterface.cpp     |  6 +--
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 54 +++++++++----------
 2 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 8aaa0704119a8..7b3107a6e6204 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -716,8 +716,8 @@ void mesh::spmdizeTriviallyShardableOperation(
   // Set the result types to the sharded counterparts.
   for (auto [oldResult, newResult, sharding] :
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
-    newResult.setType(
-        shardType(newResult.getType(),
-                  getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
+    newResult.setType(shardType(
+        newResult.getType(),
+        getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
   }
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b6cb06ae3170f..69a80b1a05bf4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -689,33 +689,33 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
 static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
-  llvm::transform(op.getResults(), std::back_inserter(res),
-                  [&op](OpResult result) {
-                    if (!result.hasOneUse() || result.use_empty()) {
-                      return MeshSharding();
-                    }
-                    TypedValue<RankedTensorType> rankedTensor =
-                        dyn_cast<TypedValue<RankedTensorType>>(result);
-                    if (!rankedTensor) {
-                      return MeshSharding();
-                    }
-                    Operation *userOp = *result.getUsers().begin();
-                    ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
-                    if (shardOp) {
-                      return MeshSharding(shardOp.getSharding());
-                    }
-                    if (rankedTensor.getType().getRank() == 0) {
-                      // This is a 0d tensor result without explicit sharding.
-                      // Find mesh symbol from operands, if any.
-                      // Shardings without mesh are not always fully supported yet.
-                      for (auto operand: op.getOperands()) {
-                        if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
-                          return MeshSharding(sharding.getMeshAttr());
-                        }
-                      }
-                    }
-                    return MeshSharding();
-                  });
+  llvm::transform(
+      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
+        if (!result.hasOneUse() || result.use_empty()) {
+          return MeshSharding();
+        }
+        TypedValue<RankedTensorType> rankedTensor =
+            dyn_cast<TypedValue<RankedTensorType>>(result);
+        if (!rankedTensor) {
+          return MeshSharding();
+        }
+        Operation *userOp = *result.getUsers().begin();
+        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+        if (shardOp) {
+          return MeshSharding(shardOp.getSharding());
+        }
+        if (rankedTensor.getType().getRank() == 0) {
+          // This is a 0d tensor result without explicit sharding.
+          // Find mesh symbol from operands, if any.
+          // Shardings without a mesh are not always fully supported yet.
+          for (auto operand : op.getOperands()) {
+            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+              return MeshSharding(sharding.getMeshAttr());
+            }
+          }
+        }
+        return MeshSharding();
+      });
   return res;
 }
 


