[llvm] [mlir] [mlir] Implement Mesh's ShardingInterface for Linalg ops (PR #82284)

Boian Petkantchin via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 07:20:07 PST 2024


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/82284

>From 8785fff09e992a8453165fff485f1d991c459e87 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Mon, 12 Feb 2024 09:57:16 -0800
Subject: [PATCH 1/2] [mlir] Implement Mesh's ShardingInterface for Linalg ops

Allows linalg structured operations to be handled during spmdization and
sharding propagation.

Only operations whose indexing maps are all projected permutations are supported.
---
 .../Dialect/Linalg/Transforms/AllInterfaces.h |  26 ++
 .../Transforms/MeshShardingInterfaceImpl.h    |  20 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |   6 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |   4 +
 .../Mesh/Interfaces/ShardingInterfaceImpl.h   |  18 +
 .../mlir/Dialect/Mesh/Transforms/Transforms.h |   6 +
 mlir/include/mlir/InitAllDialects.h           |  10 +-
 .../Linalg/Transforms/AllInterfaces.cpp       |  24 ++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   5 +
 .../Transforms/MeshShardingInterfaceImpl.cpp  | 336 ++++++++++++++++++
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |   8 -
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |   7 +
 .../Mesh/Interfaces/ShardingInterface.cpp     |  79 ++++
 .../Dialect/Mesh/Transforms/Transforms.cpp    |  13 +
 .../test/Dialect/Linalg/mesh-spmdization.mlir | 165 +++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |   3 +
 16 files changed, 714 insertions(+), 16 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
 create mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/Linalg/mesh-spmdization.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
new file mode 100644
index 00000000000000..a69751e072b797
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
@@ -0,0 +1,26 @@
+//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a common entry point for registering all external
+// interface implementations to the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+/// Registers every external interface implementation (external model) that
+/// the Linalg dialect provides, e.g. bufferization, mesh sharding, subset
+/// insertion, tiling and value-bounds models, with `registry`.
+void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
+} // namespace linalg
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
new file mode 100644
index 00000000000000..c57501ea86b7ed
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// Include guard follows the header's path (Linalg/Transforms/), matching
+// AllInterfaces.h in the same directory.
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_MESHSHARDINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_TRANSFORMS_MESHSHARDINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+/// Registers external models of mesh::ShardingInterface for Linalg
+/// structured ops with `registry`. The models are attached when the Linalg
+/// dialect is loaded.
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_MESHSHARDINGINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index fc2acc70381ef7..9d9b5892e1a51f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
   I32EnumAttrCase<"Sum", 1, "sum">,
   I32EnumAttrCase<"Max", 2, "max">,
   I32EnumAttrCase<"Min", 3, "min">,
+  I32EnumAttrCase<"Product", 4, "product">,
+  // Arithmetic mean.
+  I32EnumAttrCase<"Average", 5, "average">,
+  I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
+  I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
+  I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
   I32EnumAttrCase<"Generic", 100, "generic">
 ]> {
   let genSpecializedAttr = 0;
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ba7c111aea6bb..19020c29459821 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -340,6 +340,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
+    let builders = [
+    OpBuilder<(ins "Value":$input, "StringRef":$mesh,
+      "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
+  ];
 }
 
 def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index ffc9b6fb18be53..ab4df2ab028d43 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -22,6 +22,24 @@ class SymbolTableCollection;
 
 namespace mesh {
 
+// Retrieve the mesh axes corresponding to each operation loop iterator based
+// on the provided shardings for the op's operands and results.
+// Assumes that the indexingMaps are projected permutations: every result
+// expression of an indexing map must be an AffineDimExpr.
+// The returned array has one entry per loop iterator; iterators that no
+// sharded tensor axis maps to get an empty list of mesh axes.
+ShardingArray getMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<AffineMap> indexingMaps);
+
+// Returns true if at least one loop iterator of reduction type is assigned a
+// non-empty list of mesh axes in `meshAxisAssignmentForLoopIterators`.
+bool isAtLeastOneReductionIteratorSharded(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
+// Get the mesh axes that correspond to reduction loop iterators, concatenated
+// in loop iterator order. Duplicates are not removed.
+SmallVector<MeshAxis> getReductionMeshAxes(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
 // Inserts a clone of the operation that has all ranked tensor
 // arguments/results sharded.
 void spmdizeTriviallyShardableOperation(
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index aeab28961a4e1e..be82e2af399dc8 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -13,6 +13,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
 class RewritePatternSet;
@@ -37,6 +38,11 @@ TypedValue<IndexType>
 createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
                                  ImplicitLocOpBuilder &builder);
 
+// Get process linear index along the given mesh axes.
+// The process multi-index along `meshAxes` (mesh.process_multi_index) is
+// linearized with respect to the mesh shape along the same axes
+// (mesh.mesh_shape). Ops are created through `builder`.
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+                                               ArrayRef<MeshAxis> meshAxes,
+                                               ImplicitLocOpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index e508d51205f347..04fc0f906a8fc4 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -43,10 +43,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
@@ -155,10 +152,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   cf::registerBufferizableOpInterfaceExternalModels(registry);
   cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
-  linalg::registerBufferizableOpInterfaceExternalModels(registry);
-  linalg::registerSubsetOpInterfaceExternalModels(registry);
-  linalg::registerTilingInterfaceExternalModels(registry);
-  linalg::registerValueBoundsOpInterfaceExternalModels(registry);
+  linalg::registerAllDialectInterfaceImplementations(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
new file mode 100644
index 00000000000000..cc9f8d23231ee1
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
@@ -0,0 +1,24 @@
+//===- AllInterfaces.cpp - --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+
+#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+void mlir::linalg::registerAllDialectInterfaceImplementations(
+    DialectRegistry &registry) {
+  // Each entry registers one family of Linalg external interface models.
+  using RegisterFn = void (*)(DialectRegistry &);
+  const RegisterFn registerFns[] = {
+      registerBufferizableOpInterfaceExternalModels,
+      registerMeshShardingInterfaceExternalModels,
+      registerSubsetOpInterfaceExternalModels,
+      registerTilingInterfaceExternalModels,
+      registerValueBoundsOpInterfaceExternalModels,
+  };
+  for (RegisterFn registerFn : registerFns)
+    registerFn(registry);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4f47e3b8718454..513c54de5d7bfc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  AllInterfaces.cpp
   BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
   Promotion.cpp
@@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRMemRefTransforms
+  MLIRMeshDialect
+  MLIRMeshTransforms
   MLIRLinalgDialect
   MLIRLinalgUtils
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRPass
+  MLIRShardingInterface
   MLIRSubsetOpInterface
   MLIRSparseTensorDialect
   MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
new file mode 100644
index 00000000000000..621885974b2ef3
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -0,0 +1,336 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
+/// Maps the scalar combiner operation of a reduction to the mesh collective
+/// reduction kind used to combine partial results across processes.
+/// Any operation not listed below yields ReductionKind::Generic.
+/// NOTE(review): signed and unsigned integer min/max map to the same Min/Max
+/// kind, so the signedness of the comparison is not carried over.
+static ReductionKind getReductionKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+      // Floating-point operations.
+      .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+      .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+      .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+      // Integer operations.
+      .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+      .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+      .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+      // Partial results of an AND reduction must be combined with AND;
+      // summing them would produce wrong values.
+      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
+      .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+      .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+
+/// Returns the single combiner operation that matchReduction finds for the
+/// region output argument at position 0 of `op`.
+/// Returns std::nullopt when no reduction is matched, or when the matched
+/// reduction consists of more than one combiner operation.
+static std::optional<Operation *> getReductionOp(LinalgOp op) {
+  SmallVector<Operation *> combiners;
+  Value reduced = matchReduction(op.getRegionOutputArgs(), 0, combiners);
+  bool isSingleCombinerReduction = reduced && combiners.size() == 1;
+  if (isSingleCombinerReduction)
+    return combiners.front();
+  return std::nullopt;
+}
+
+/// Returns the mesh reduction kind matching the single combiner operation of
+/// `op`'s reduction, or ReductionKind::Generic when none is found.
+static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
+  if (std::optional<Operation *> combiner = getReductionOp(op))
+    return getReductionKind(*combiner);
+  return ReductionKind::Generic;
+}
+
+/// Returns the mesh referenced by the first non-null sharding. Operand
+/// shardings are searched before result shardings.
+/// At least one operand or result sharding must be non-null.
+static MeshOp getMesh(Operation *op,
+                      ArrayRef<MeshShardingAttr> operandShardings,
+                      ArrayRef<MeshShardingAttr> resultShardings,
+                      SymbolTableCollection &symbolTable) {
+  for (MeshShardingAttr sharding : operandShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+
+  for (MeshShardingAttr sharding : resultShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+
+  // Reaching this point is a caller bug. llvm_unreachable also keeps release
+  // builds (where asserts vanish) from falling off the end of this non-void
+  // function.
+  llvm_unreachable("expected at least one non-null operand or result sharding");
+}
+
+// Creates the value to use as the destination-passing-style init operand of
+// the spmdized op, chosen by the process index along the reduction mesh axes.
+// The original initial value must enter the reduction only once, not once
+// per process. Within each reduction process group (the processes along
+// `reductionMeshAxes`), only the leading process with linear index 0 uses
+// `spmdizedOperand`. Every other process uses the tensor produced by
+// PartialReductionOpInterface::generateInitialTensorForPartialReduction
+// (the reduction operation's neutral tensor). The choice is an scf.if on the
+// process linear index.
+//
+// NOTE(review): `op` must implement PartialReductionOpInterface (llvm::cast
+// below). A failure to generate the neutral tensor is caught only by the
+// assert; in release builds `.value()` would then be called on a failed result.
+static Value createDestinationPassingStyleInitOperand(
+    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
+    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
+      meshOp.getSymName(), reductionMeshAxes, builder);
+  Value zero = builder.create<arith::ConstantIndexOp>(0);
+  Value isLeadProcess = builder.create<arith::CmpIOp>(
+      builder.getI1Type(), arith::CmpIPredicate::eq,
+      processLinearIndexInReductionGroup, zero);
+  // Both region blocks are created here and populated below.
+  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+                                             isLeadProcess, true, true);
+  // Then block.
+  {
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+    builder.create<scf::YieldOp>(spmdizedOperand);
+  }
+
+  // Else block.
+  {
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
+    // The neutral tensor takes the (possibly dynamic) shape of the local
+    // shard of the init operand.
+    SmallVector<OpFoldResult> shape =
+        tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+    PartialReductionOpInterface partialReductionIface =
+        llvm::cast<PartialReductionOpInterface>(op.getOperation());
+    FailureOr<Operation *> reductionNeutralTensorOp =
+        partialReductionIface.generateInitialTensorForPartialReduction(
+            builder, builder.getLoc(), shape, {});
+    assert(succeeded(reductionNeutralTensorOp));
+    builder.create<scf::YieldOp>(
+        reductionNeutralTensorOp.value()->getResult(0));
+  }
+  return ifOp.getResult(0);
+}
+
+// Builds the full operand list for the spmdized Linalg op: a copy of
+// `spmdizedOperands` where the first DPS init operand is replaced by the
+// value from createDestinationPassingStyleInitOperand. Other operands are
+// passed through unchanged.
+static SmallVector<Value> createDestinationPassingStyleInitOperands(
+    LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+    ImplicitLocOpBuilder &builder) {
+  // TODO: add support for multiple destination passing style initial value
+  // operands.
+  // PartialReductionOpInterface::generateInitialTensorForPartialReduction
+  // needs to also support multiple DPS initial operands.
+  unsigned initOperandNumber = op.getDpsInitOperand(0)->getOperandNumber();
+  Value unshardedInit = op->getOperand(initOperandNumber);
+  Value spmdizedInit = spmdizationMap.lookup(unshardedInit);
+
+  SmallVector<Value> operands(spmdizedOperands.begin(),
+                              spmdizedOperands.end());
+  operands[initOperandNumber] = createDestinationPassingStyleInitOperand(
+      op, spmdizedInit, reductionMeshAxes, meshOp, builder);
+  return operands;
+}
+
+// Emits a mesh.all_reduce over those reduction mesh axes of the op that are
+// not partial axes of `resultSharding`. It then remaps
+// `unshardedLinalgOpResult` to the reduced value in `spmdizationMap`.
+// Emits nothing when every reduction mesh axis is already a partial axis.
+static void createAllReduceForResultWithoutPartialSharding(
+    Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
+    MeshShardingAttr resultSharding, ReductionKind reductionKind,
+    IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
+  SmallVector<MeshAxis> allReduceMeshAxes;
+  for (MeshAxis axis : opReductionMeshAxes) {
+    bool isPartialAxis =
+        llvm::is_contained(resultSharding.getPartialAxes(), axis);
+    if (!isPartialAxis)
+      allReduceMeshAxes.push_back(axis);
+  }
+  if (allReduceMeshAxes.empty())
+    return;
+
+  StringRef meshName = resultSharding.getMesh().getValue();
+  Value localResult = spmdizationMap.lookup(unshardedLinalgOpResult);
+  Value reduced = builder.create<mesh::AllReduceOp>(
+      localResult, meshName, allReduceMeshAxes, reductionKind);
+  spmdizationMap.map(unshardedLinalgOpResult, reduced);
+}
+
+// Applies createAllReduceForResultWithoutPartialSharding to each result of
+// `unshardedOp` together with its sharding. The reduction kind is derived
+// once from the op's combiner operation.
+static void createAllReduceForResultsWithoutPartialShardings(
+    LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
+    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ImplicitLocOpBuilder &builder) {
+  ReductionKind kind = getReductionKindOfLinalgOp(unshardedOp);
+  for (auto resultAndSharding :
+       llvm::zip(unshardedOp->getResults(), resultShardings))
+    createAllReduceForResultWithoutPartialSharding(
+        std::get<0>(resultAndSharding), opReductionMeshAxes,
+        std::get<1>(resultAndSharding), kind, spmdizationMap, builder);
+}
+
+// Spmdizes a Linalg op in which at least one reduction loop iterator is
+// sharded:
+//  1. replaces the DPS init operand so that the initial value enters the
+//     reduction only once per reduction process group,
+//  2. clones the op onto the spmdized operands, and
+//  3. all-reduces the results whose sharding does not already mark the
+//     reduction mesh axes as partial.
+static void spmdizeLinalgOpWithShardedReduction(
+    LinalgOp op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
+    IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+    ImplicitLocOpBuilder &builder) {
+  MeshOp meshOp = getMesh(op, operandShardings, resultShardings, symbolTable);
+  SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
+      loopIteratorTypes, meshAxisAssignmentForLoopIterators);
+  SmallVector<Value> linalgOpOperands =
+      createDestinationPassingStyleInitOperands(op, meshOp, spmdizedOperands,
+                                                reductionMeshAxes,
+                                                spmdizationMap, builder);
+
+  // The caller's spmdizationMap holds operand mappings for the whole
+  // spmdization and may be used by others, so its operand entries must not
+  // be overwritten. Clone through a private mapping instead.
+  IRMapping localMap;
+  localMap.map(op->getOperands(), linalgOpOperands);
+  mesh::spmdizeTriviallyShardableOperation(*op, linalgOpOperands,
+                                           operandShardings, resultShardings,
+                                           localMap, symbolTable, builder);
+  // Only the result mappings are published back to the caller.
+  for (Value result : op->getResults())
+    spmdizationMap.map(result, localMap.lookup(result));
+
+  // Handle partial shardings.
+  createAllReduceForResultsWithoutPartialShardings(
+      op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
+}
+
+namespace {
+
+// ShardingInterface external model for ops that implement
+// LinalgStructuredInterface. Only ops whose indexing maps are all projected
+// permutations can be spmdized; `spmdize` rejects the others with an error.
+template <typename Op>
+struct StructuredOpShardingInterface
+    : public mesh::ShardingInterface::ExternalModel<
+          StructuredOpShardingInterface<Op>, Op> {
+  // Iterator type of each loop of the op.
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+    return linalgOp.getIteratorTypesArray();
+  }
+
+  // Indexing maps of all operands, followed by one map per DPS init. Results
+  // must have the same indexing as their destination passing style initial
+  // operands, so each appended map is a copy of that init operand's map.
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
+    int64_t numDpsInits = linalgOp.getNumDpsInits();
+    for (int64_t initIdx = 0; initIdx < numDpsInits; ++initIdx) {
+      unsigned operandNumber =
+          linalgOp.getDpsInitOperand(initIdx)->getOperandNumber();
+      AffineMap initMap = maps[operandNumber];
+      maps.push_back(initMap);
+    }
+    return maps;
+  }
+
+  // Spmdizes `op`. If any reduction loop iterator is sharded, the partial
+  // results are combined across processes. Otherwise the op is simply cloned
+  // onto its sharded operands.
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshShardingAttr> operandShardings,
+                        ArrayRef<MeshShardingAttr> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+
+    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+    bool hasNonProjectedPermutation =
+        llvm::any_of(indexingMaps, [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        });
+    if (hasNonProjectedPermutation) {
+      // TODO: handle non-projected permutations.
+      op->emitOpError()
+          << "Only projected permutation indexing maps are supported.";
+      return failure();
+    }
+
+    SmallVector<utils::IteratorType> loopIteratorTypes =
+        linalgOp.getIteratorTypesArray();
+    ShardingArray meshAxisAssignmentForLoopIterators =
+        mesh::getMeshAxisAssignmentForLoopIterators(
+            operandShardings, resultShardings, loopIteratorTypes, indexingMaps);
+    bool hasShardedReduction = mesh::isAtLeastOneReductionIteratorSharded(
+        loopIteratorTypes, meshAxisAssignmentForLoopIterators);
+
+    if (!hasShardedReduction) {
+      mesh::spmdizeTriviallyShardableOperation(
+          *op, spmdizedOperands, operandShardings, resultShardings,
+          spmdizationMap, symbolTable, builder);
+      return success();
+    }
+
+    ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
+    spmdizeLinalgOpWithShardedReduction(
+        linalgOp, spmdizedOperands, operandShardings, resultShardings,
+        loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
+        symbolTable, implicitLocBuilder);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Attaches the StructuredOpShardingInterface external model to `OpType`.
+template <typename OpType>
+static void registerOne(MLIRContext *ctx) {
+  using ShardingModel = StructuredOpShardingInterface<OpType>;
+  OpType::template attachInterface<ShardingModel>(*ctx);
+}
+
+/// Attaches the sharding external model to every op type in `OpTypes`.
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
+  (registerOne<OpTypes>(ctx), ...);
+}
+
+/// Registers a Linalg dialect extension. When the extension runs, it loads
+/// the affine, arith, scf and tensor dialects into the context. It then
+/// attaches StructuredOpShardingInterface to linalg.generic and to every op
+/// in the LinalgStructuredOps op list.
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
+    // A distinct name avoids shadowing the enclosing `registry` parameter.
+    DialectRegistry dependentDialects;
+    dependentDialects.insert<affine::AffineDialect, arith::ArithDialect,
+                             scf::SCFDialect, tensor::TensorDialect>();
+    ctx->appendDialectRegistry(dependentDialects);
+    for (StringRef dialectName : dependentDialects.getDialectNames())
+      ctx->getOrLoadDialect(dialectName);
+
+    registerOne<linalg::GenericOp>(ctx);
+    registerAll<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+        >(ctx);
+  });
+}
+
+} // namespace mlir::linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 8b3119f02e8fda..bd870d4f982e5d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -275,14 +275,6 @@ struct LinalgOpPartialReductionInterface
     ArrayRef<int64_t> oldShape =
         linalgOp.getShape(linalgOp.getDpsInitOperand(0));
 
-    // Extend tile size vector to the rank of the output tensor.
-    SmallVector<Value> tileSizeVector =
-        getValueOrCreateConstantIndexOp(b, loc, sizes);
-    if (tileSizeVector.size() < oldShape.size()) {
-      auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
-      tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero);
-    }
-
     // Calculate the new shape, we insert the new dimensions based on the index
     // of the reduction dimensions.
     SmallVector<int64_t> newOutputShape;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 838255cf5a5ba3..34282fab855792 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -620,6 +620,13 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+// Convenience builder whose result type is taken from `input`. All other
+// arguments are forwarded to the generated builder.
+void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        Value input, StringRef mesh,
+                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
+  Type resultType = input.getType();
+  build(odsBuilder, odsState, resultType, mesh, meshAxes, input, reduction);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_slice op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index fe3d7c44413fef..50e461a6d50927 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -563,6 +563,85 @@ void mesh::spmdizeFullyReplicatedOperation(
   builder.clone(op, spmdizationMap);
 }
 
+// Records that the loop iterator referenced by `indexingExpr` (which must be
+// an AffineDimExpr) is sharded along `meshAxesAssignmentForTensorAxis`.
+// If the iterator already has an assignment, asserts that the two agree.
+static void updateMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+    SmallVector<std::optional<SmallVector<MeshAxis>>>
+        &meshAxesAssignmentForLoopIterators) {
+  unsigned loopIdx = cast<AffineDimExpr>(indexingExpr).getPosition();
+  std::optional<SmallVector<MeshAxis>> &currentAssignment =
+      meshAxesAssignmentForLoopIterators[loopIdx];
+  if (!currentAssignment) {
+    currentAssignment = llvm::to_vector(meshAxesAssignmentForTensorAxis);
+    return;
+  }
+  assert(llvm::equal(meshAxesAssignmentForTensorAxis, *currentAssignment));
+}
+
+// Computes, for each loop iterator, the mesh axes it is sharded along. It
+// walks the split axes of every non-null sharding together with the
+// matching indexing map. Shardings are taken in order: operands first, then
+// results. Shardings beyond `indexingMaps.size()` are ignored (llvm::zip
+// stops at the shorter range), so callers that pass only operand indexing
+// maps get operand-only behavior. Iterators that no sharded tensor axis
+// maps to get an empty list.
+ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<AffineMap> indexingMaps) {
+  SmallVector<std::optional<SmallVector<MeshAxis>>>
+      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+  // The result shardings belong in this list too, as the function contract
+  // states; leaving them out would silently ignore result indexing maps.
+  SmallVector<MeshShardingAttr> operatorAndResultShardings;
+  operatorAndResultShardings.reserve(operandShardings.size() +
+                                     resultShardings.size());
+  llvm::append_range(operatorAndResultShardings, operandShardings);
+  llvm::append_range(operatorAndResultShardings, resultShardings);
+  for (auto [sharding, affineMap] :
+       llvm::zip(operatorAndResultShardings, indexingMaps)) {
+    if (!sharding) {
+      continue;
+    }
+    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
+      updateMeshAxisAssignmentForLoopIterators(
+          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+          meshAxisAssignmentForLoopIterators);
+    }
+  }
+
+  ShardingArray res;
+  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
+                  [](std::optional<SmallVector<MeshAxis>> &axes) {
+                    if (!axes) {
+                      return SmallVector<MeshAxis>();
+                    }
+                    return std::move(*axes);
+                  });
+  return res;
+}
+
+// True when some reduction loop iterator has at least one mesh axis assigned.
+bool mesh::isAtLeastOneReductionIteratorSharded(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+  auto isShardedReduction = [](auto typeAndAxes) {
+    return std::get<0>(typeAndAxes) == utils::IteratorType::reduction &&
+           !std::get<1>(typeAndAxes).empty();
+  };
+  return llvm::any_of(
+      llvm::zip(loopIteratorTypes, meshAxisAssignmentForLoopIterators),
+      isShardedReduction);
+}
+
+// Concatenates, in loop iterator order, the mesh axes assigned to reduction
+// loop iterators. Duplicates are kept.
+SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+  SmallVector<MeshAxis> reductionAxes;
+  for (auto [iteratorType, assignedAxes] :
+       llvm::zip(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+    if (iteratorType != utils::IteratorType::reduction)
+      continue;
+    llvm::append_range(reductionAxes, assignedAxes);
+  }
+  return reductionAxes;
+}
+
 void mesh::spmdizeTriviallyShardableOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
     ArrayRef<MeshShardingAttr> operandShardings,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index d59b9119dea541..cb13ee404751ca 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -208,4 +208,17 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
       .cast<TypedValue<IndexType>>();
 }
 
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+                                               ArrayRef<MeshAxis> meshAxes,
+                                               ImplicitLocOpBuilder &builder) {
+  // Index of the current process along `meshAxes`, and the mesh extents along
+  // the same axes, which serve as the basis for linearization.
+  auto multiIndexOp = builder.create<ProcessMultiIndexOp>(mesh, meshAxes);
+  auto groupShapeOp = builder.create<MeshShapeOp>(mesh, meshAxes);
+  OpFoldResult linearIndex = affine::linearizeIndex(
+      llvm::to_vector_of<OpFoldResult>(multiIndexOp.getResults()),
+      llvm::to_vector_of<OpFoldResult>(groupShapeOp.getResult()), builder);
+  // `get<Value>()` asserts that the linearization did not fold to an attribute.
+  return cast<TypedValue<IndexType>>(linearIndex.get<Value>());
+}
+
 } // namespace mlir::mesh
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
new file mode 100644
index 00000000000000..6d21def8de2753
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt \
+// RUN:  --mesh-spmdization \
+// RUN:  --test-constant-fold \
+// RUN:  --split-input-file \
+// RUN:  %s | FileCheck %s
+
+// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)>
+#map_identity_1d = affine_map<(d0) -> (d0)>
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// Elementwise op whose only (parallel) iterator is split over mesh axis 0:
+// after spmdization the same linalg.generic runs on the local 1-element
+// shards and its result is returned as is.
+// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor
+func.func @elementwise_static_1d_mesh_static_1d_tensor(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>,
+  %in1: tensor<2xi8>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>,
+  %in2: tensor<2xi8>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1xi8>
+  %dps_out: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+  %in1_sharded1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %in2_sharded1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %in2_sharded2 = mesh.shard %in2_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %dps_out_sharded1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %dps_out_sharded2 = mesh.shard %dps_out_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: %[[RES:.*]] = linalg.generic {
+  // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
+  // CHECK-SAME: iterator_types = ["parallel"]}
+  // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1xi8>, tensor<1xi8>)
+  // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1xi8>) {
+  %res = linalg.generic {
+      indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d],
+      iterator_types = ["parallel"]
+    } ins(%in1_sharded2, %in2_sharded2 : tensor<2xi8>, tensor<2xi8>)
+      outs(%dps_out_sharded2 : tensor<2xi8>) {
+    ^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8):
+      %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
+      linalg.yield %res_scalar : i8
+    } -> tensor<2xi8>
+  %res_sharded1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<2xi8>
+  %res_sharded2 = mesh.shard %res_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  // CHECK: return %[[RES]] : tensor<1xi8>
+  return %res_sharded2 : tensor<2xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 4)
+
+// Matmul with the parallel row iterator (dim 0 of the LHS and of the result)
+// split over 4 processes: each process multiplies its 1x3 LHS shard by the
+// replicated 3x8 RHS, and the local 1x8 result is returned directly.
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding
+func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>,
+  %in1: tensor<4x3xi8>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>,
+  %in2: tensor<3x8xi8>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1x8xi8>
+  %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<1x8xi8> {
+) -> tensor<4x8xi8> {
+  %in1_sharded1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<4x3xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x3xi8>
+  %in2_sharded1 = mesh.shard %in2 to <@mesh_1d, [[]]> : tensor<3x8xi8>
+  %in2_sharded2 = mesh.shard %in2_sharded1 to <@mesh_1d, [[]]> annotate_for_users: tensor<3x8xi8>
+  %dps_out_sharded1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<4x8xi8>
+  %dps_out_sharded2 = mesh.shard %dps_out_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+  // CHECK: %[[RES:.*]] = linalg.matmul
+  // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
+  // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
+  // CHECK-SAME: -> tensor<1x8xi8>
+  %res = linalg.matmul ins(%in1_sharded2, %in2_sharded2 : tensor<4x3xi8>, tensor<3x8xi8>)
+      outs(%dps_out_sharded2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+  %res_sharded1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<4x8xi8>
+  %res_sharded2 = mesh.shard %res_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+  // CHECK: return %[[RES]] : tensor<1x8xi8>
+  return %res_sharded2 : tensor<4x8xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 3)
+
+// Matmul with the reduction iterator (dim 1 of the LHS, dim 0 of the RHS)
+// split over 3 processes while the result is replicated: only the process
+// with index 0 accumulates into the original DPS init, the other processes
+// start from a zero-filled tensor, and the partial products are combined
+// with mesh.all_reduce.
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding
+func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
+  %in1: tensor<4x6xi8>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
+  %in2: tensor<6x8xi8>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
+  %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<4x8xi8> {
+) -> tensor<4x8xi8> {
+  %in1_sharded1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
+  %in2_sharded1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
+  %in2_sharded2 = mesh.shard %in2_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
+  %dps_out_sharded1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
+  %dps_out_sharded2 = mesh.shard %dps_out_sharded1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+  // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:  %[[C0_I8:.*]] = arith.constant 0 : i8
+  // CHECK-DAG:  %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK-DAG:  %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+  // CHECK:      %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
+  // CHECK:      %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
+  // CHECK:        scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
+  // CHECK:      } else {
+  // CHECK-DAG:    %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
+  // CHECK:        %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
+  // CHECK-SAME:       outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+  // CHECK:        scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
+  // CHECK:      }
+  // CHECK:      %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
+  // CHECK-SAME:     outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+  // CHECK:      %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
+  %res = linalg.matmul ins(%in1_sharded2, %in2_sharded2 : tensor<4x6xi8>, tensor<6x8xi8>)
+      outs(%dps_out_sharded2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+  %res_sharded1 = mesh.shard %res to <@mesh_1d, [[]]> : tensor<4x8xi8>
+  %res_sharded2 = mesh.shard %res_sharded1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+  // CHECK:      return %[[ALL_REDUCED]] : tensor<4x8xi8>
+  return %res_sharded2 : tensor<4x8xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 3)
+
+// Reduction iterator split over 3 processes as in the previous case, but the
+// result is annotated with `partial = sum[0]`: the local matmul result is
+// returned without being combined through mesh.all_reduce.
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result
+func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
+  %in1: tensor<4x6xi8>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
+  %in2: tensor<6x8xi8>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
+  %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<4x8xi8> {
+) -> tensor<4x8xi8> {
+  %in1_sharded1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
+  %in2_sharded1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
+  %in2_sharded2 = mesh.shard %in2_sharded1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
+  %dps_out_sharded1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
+  %dps_out_sharded2 = mesh.shard %dps_out_sharded1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+  // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:  %[[C0_I8:.*]] = arith.constant 0 : i8
+  // CHECK-DAG:  %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK-DAG:  %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+  // CHECK:      %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
+  // CHECK:      %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
+  // CHECK:        scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
+  // CHECK:      } else {
+  // CHECK-DAG:    %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
+  // CHECK:        %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
+  // CHECK-SAME:       outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+  // CHECK:        scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
+  // CHECK:      }
+  // CHECK:      %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
+  // CHECK-SAME:     outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+  %res = linalg.matmul ins(%in1_sharded2, %in2_sharded2 : tensor<4x6xi8>, tensor<6x8xi8>)
+      outs(%dps_out_sharded2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+  %res_sharded1 = mesh.shard %res to <@mesh_1d, [[]], partial = sum[0]> : tensor<4x8xi8>
+  %res_sharded2 = mesh.shard %res_sharded1 to <@mesh_1d, [[]], partial = sum[0]> annotate_for_users: tensor<4x8xi8>
+  // CHECK:      return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
+  return %res_sharded2 : tensor<4x8xi8>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a21bc01aa1e3ca..a113365cbc47cf 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11002,10 +11002,13 @@ cc_library(
         ":MathDialect",
         ":MemRefDialect",
         ":MemRefTransforms",
+        ":MeshDialect",
+        ":MeshTransforms",
         ":Pass",
         ":SCFDialect",
         ":SCFTransforms",
         ":SCFUtils",
+        ":MeshShardingInterface",
         ":SparseTensorDialect",
         ":SubsetOpInterface",
         ":Support",

>From 1a84fafb817d405034667be6bf4edfd6dc134b74 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 21 Feb 2024 07:19:36 -0800
Subject: [PATCH 2/2] Add sharding interface promise to the Linalg dialect

---
 mlir/include/mlir/IR/Dialect.h                    | 5 +++++
 mlir/lib/Dialect/Linalg/IR/CMakeLists.txt         | 1 +
 mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp      | 7 +++++++
 utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 +
 4 files changed, 14 insertions(+)

diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 50f6f6de5c2897..e43212bba19a52 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -216,6 +216,11 @@ class Dialect {
         {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
   }
 
+  /// Declares `InterfaceT` as a promised interface of every operation type in
+  /// the `ConcreteTs` pack, i.e. invokes
+  /// `declarePromisedInterface<Op, InterfaceT>()` once per `Op` in the pack.
+  /// The interface comes first so that the variadic op pack can be last.
+  template <typename InterfaceT, typename... ConcreteTs>
+  void declarePromisedInterfaces() {
+    ((void)declarePromisedInterface<ConcreteTs, InterfaceT>(), ...);
+  }
+
   /// Checks if the given interface, which is attempting to be used, is a
   /// promised interface of this dialect that has yet to be implemented. If so,
   /// emits a fatal error. `interfaceName` is an optional string that contains a
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index f0ac1899bb02ab..c187563b8f0c4e 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRParser
+  MLIRShardingInterface
   MLIRSideEffectInterfaces
   MLIRSparseTensorDialect
   MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
index 5069d43e7db95f..027058d4de6328 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -118,6 +119,12 @@ void mlir::linalg::LinalgDialect::initialize() {
       >(namedStructuredOpRegionBuilders);
 
   addInterfaces<LinalgInlinerInterface>();
+
+  declarePromisedInterface<GenericOp, mesh::ShardingInterface>();
+  declarePromisedInterfaces<mesh::ShardingInterface,
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+                            >();
 }
 
 LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a113365cbc47cf..3d0ddf775a0f72 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10849,6 +10849,7 @@ cc_library(
         ":MemRefDialect",
         ":Parser",
         ":SCFDialect",
+        ":MeshShardingInterface",
         ":SideEffectInterfaces",
         ":SparseTensorDialect",
         ":Support",



More information about the llvm-commits mailing list