[Mlir-commits] [mlir] fixes for 0d tensors (PR #132948)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Mar 25 08:48:40 PDT 2025
https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/132948
0-d tensors are generally treated as scalars; in particular, they are always replicated.
In some cases 0-d tensors have no sharding at all. This PR provides a few minor fixes to handle such cases.
@tkarna Could you please have a look at this?
>From f54a37cb31285aa94fae0c60e0c7560841d1708d Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 25 Mar 2025 16:32:39 +0100
Subject: [PATCH] fixes for 0d tensors
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 2 ++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 2 +-
.../Mesh/Interfaces/ShardingInterface.cpp | 2 +-
.../Dialect/Mesh/Transforms/Spmdization.cpp | 29 ++++++++++++++-----
.../Extensions/MeshShardingExtensions.cpp | 18 ++++++++----
.../test/Dialect/Tensor/mesh-spmdization.mlir | 7 +++++
6 files changed, 44 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index fc5cfffea27a7..32c2eca2cefa8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -119,6 +119,8 @@ inline bool isFullReplication(MeshSharding sharding) {
inline mesh::MeshOp
getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTableCollection) {
+ if (!meshSymbol)
+ return nullptr;
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
op, meshSymbol);
}
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3e9f86fde64f3..65475b69dbdb1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -269,7 +269,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
- if (rankedTensorType) {
+ if (rankedTensorType && !rankedTensorType.getShape().empty()) {
return shardShapedType(rankedTensorType, mesh, sharding);
}
return type;
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index f427d004c558f..8aaa0704119a8 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -718,6 +718,6 @@ void mesh::spmdizeTriviallyShardableOperation(
llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
newResult.setType(
shardType(newResult.getType(),
- getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
+ getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
}
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 601af0200e785..b6cb06ae3170f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
block.getArguments(), std::back_inserter(res),
[&symbolTableCollection](BlockArgument arg) {
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
- if (!rankedTensorArg) {
+ if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
return arg.getType();
}
@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(operand);
- if (!rankedTensor) {
+ if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
return MeshSharding();
}
@@ -690,18 +690,31 @@ static std::vector<MeshSharding> getResultShardings(Operation &op) {
std::vector<MeshSharding> res;
res.reserve(op.getNumResults());
llvm::transform(op.getResults(), std::back_inserter(res),
- [](OpResult result) {
+ [&op](OpResult result) {
+ if (!result.hasOneUse() || result.use_empty()) {
+ return MeshSharding();
+ }
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
return MeshSharding();
}
- if (!result.hasOneUse()) {
- return MeshSharding();
- }
Operation *userOp = *result.getUsers().begin();
- ShardOp shardOp = llvm::cast<ShardOp>(userOp);
- return MeshSharding(shardOp.getSharding());
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+ if (shardOp) {
+ return MeshSharding(shardOp.getSharding());
+ }
+ if (rankedTensor.getType().getRank() == 0) {
+ // This is a 0d tensor result without explicit sharding.
+ // Find mesh symbol from operands, if any.
+ // Shardings without mesh are not always fully supported yet.
+ for (auto operand: op.getOperands()) {
+ if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+ return MeshSharding(sharding.getMeshAttr());
+ }
+ }
+ }
+ return MeshSharding();
});
return res;
}
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index b3d69eb5e1a23..fc93f1c1c9220 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -50,19 +50,25 @@ struct CreatorOpShardingInterface
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
- auto mesh =
- mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
- auto shardType = cast<ShapedType>(
- mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
+ assert(resultShardings.size() == 1);
+ auto resType = cast<RankedTensorType>(op->getResult(0).getType());
+ mlir::mesh::MeshOp mesh;
+ ShapedType shardType;
+ if (resType.getRank() > 0) {
+ mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+ shardType =
+ cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+ } else {
+ shardType = resType;
+ }
Operation *newOp = nullptr;
// if the sharding introduces a new dynamic dimension, we take it from
// the dynamic sharding info. For now bail out if it's not
// provided.
- assert(resultShardings.size() == 1);
if (!shardType.hasStaticShape()) {
assert(op->getResult(0).hasOneUse());
SmallVector<Value> newOperands;
- auto oldType = cast<ShapedType>(op->getResult(0).getType());
+ auto oldType = cast<ShapedType>(resType);
assert(oldType.getRank() == shardType.getRank());
int currOldOprndNum = -1;
mesh::ShardShapeOp shapeForDevice;
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 01cf5972177f4..3fb8424745501 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -43,3 +43,10 @@ func.func @tensor_empty_same_static_dims_sizes() -> () {
return
}
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+ tensor.empty() : tensor<f32>
+ // CHECK-NEXT: tensor.empty() : tensor<f32>
+ return
+}
More information about the Mlir-commits
mailing list