[Mlir-commits] [mlir] [mlir][mesh] Add lowering of process multi-index op (PR #77490)

Boian Petkantchin llvmlistbot at llvm.org
Tue Jan 9 08:05:09 PST 2024


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/77490

>From 70c4071102a1f6ec84648f498b11552409e84864 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 5 Jan 2024 17:48:36 -0800
Subject: [PATCH] [mlir][mesh] Add lowering of process multi-index op

* Rename mesh.process_index -> mesh.process_multi_index.
* Add mesh.process_linear_index op.
* Add lowering of mesh.process_multi_index into an expression using
mesh.process_linear_index, mesh.cluster_shape and affine.delinearize_index.

This is useful when lowering mesh ops further for runtimes that only provide
the linear index of a device/process.
For example, in MPI each process has a rank (a linear index) within a
communicator.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 32 +++++++-
 .../mlir/Dialect/Mesh/Transforms/Transforms.h | 26 ++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 33 ++++++--
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |  2 +
 .../Mesh/Transforms/Simplifications.cpp       |  2 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  4 +-
 .../Dialect/Mesh/Transforms/Transforms.cpp    | 82 +++++++++++++++++++
 mlir/test/Dialect/Mesh/invalid.mlir           | 20 ++---
 mlir/test/Dialect/Mesh/ops.mlir               | 24 +++---
 .../Mesh/process-multi-index-op-lowering.mlir | 23 ++++++
 .../Dialect/Mesh/resharding-spmdization.mlir  |  4 +-
 mlir/test/lib/Dialect/Mesh/CMakeLists.txt     |  1 +
 .../Mesh/TestProcessMultiIndexOpLowering.cpp  | 55 +++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 14 files changed, 273 insertions(+), 37 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
 create mode 100644 mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
 create mode 100644 mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f459077ea12022..a9068562f5c903 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,7 +96,8 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   let hasVerifier = 1;
 }
 
-def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
+  Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "Get the shape of the cluster.";
   let arguments = (ins
     FlatSymbolRefAttr:$mesh,
@@ -209,11 +210,15 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
-def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
-  let summary = "Get the index of current device along specified mesh axis.";
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+]> {
+  let summary = "Get the multi index of current device along specified mesh axes.";
   let description = [{
     It is used in the SPMD format of IR.
     The `axes` mush be non-negative and less than the total number of mesh axes.
+    If the axes are empty then get the index along all axes.
   }];
   let arguments = (ins
     FlatSymbolRefAttr:$mesh,
@@ -232,6 +237,27 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
   ];
 }
 
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+]> {
+  let summary = "Get the linear index of the current device.";
+  let description = [{
+    Example:
+    ```
+    %idx = mesh.process_linear_index on @mesh : index
+    ```
+    If `@mesh` has shape `(10, 20, 30)`, a device with multi-index
+    `(1, 2, 3)` has the row-major linear index `3 + 30*2 + 20*30*1`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh);
+  let results = (outs Index:$result);
+  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
new file mode 100644
index 00000000000000..10a965daac71b9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -0,0 +1,35 @@
+//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+
+namespace mlir {
+class RewritePatternSet;
+class SymbolTableCollection;
+class DialectRegistry;
+namespace mesh {
+
+/// Populate `patterns` with a pattern that lowers `mesh.process_multi_index`
+/// into `mesh.process_linear_index`, `mesh.cluster_shape` and
+/// `affine.delinearize_index`.
+///
+/// The pattern keeps a reference to `symbolTableCollection` (used to resolve
+/// the referenced mesh), so the collection must outlive every use of
+/// `patterns`.
+void processMultiIndexOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+
+/// Register the dialects whose ops are created by the patterns added in
+/// `processMultiIndexOpLoweringPopulatePatterns` (Affine and Mesh).
+void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 6667d409df8b78..9b110c462915e7 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -250,7 +250,8 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                            ClusterOp mesh) {
   build(odsBuilder, odsState,
         SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
-        mesh.getSymName(), MeshAxesAttr());
+        mesh.getSymName(),
+        MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
 }
 
 void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
@@ -325,11 +326,11 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.process_index op
+// mesh.process_multi_index op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
@@ -348,20 +349,38 @@ ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
-void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                           ClusterOp mesh) {
+void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                                ClusterOp mesh) {
   build(odsBuilder, odsState,
         SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
         mesh.getSymName(), MeshAxesAttr());
 }
 
-void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                                StringRef mesh, ArrayRef<MeshAxis> axes) {
   build(odsBuilder, odsState,
         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.process_linear_index op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // The mesh attribute is the op's only symbol reference; it must resolve.
+  auto resolvedMesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  return failed(resolvedMesh) ? failure() : success();
+}
+
+void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
+                                 OperationState &odsState, ClusterOp mesh) {
+  // Forward to the builder that takes the mesh symbol name.
+  StringRef meshName = mesh.getSymName();
+  build(odsBuilder, odsState, meshName);
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index 7a70c047ec9dce..dccb75848c94f0 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
   Simplifications.cpp
   ShardingPropagation.cpp
   Spmdization.cpp
+  Transforms.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
   MLIRShardingInterface
 
   LINK_LIBS PUBLIC
+  MLIRAffineDialect
   MLIRArithDialect
   MLIRControlFlowDialect
   MLIRFuncDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 643bd7b8e77c93..c0f081ff8ceff4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -1,4 +1,4 @@
-//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
+//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 37b86535959652..0e83c024fc08f8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -206,8 +206,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
 
   Value processIndexAlongAxis =
       builder
-          .create<ProcessIndexOp>(mesh.getSymName(),
-                                  SmallVector<MeshAxis>({splitMeshAxis}))
+          .create<ProcessMultiIndexOp>(mesh.getSymName(),
+                                       SmallVector<MeshAxis>({splitMeshAxis}))
           .getResult()[0];
 
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
new file mode 100644
index 00000000000000..5aef7a2396ccff
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -0,0 +1,96 @@
+//===- Transforms.cpp ---------------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <iterator>
+#include <utility>
+
+namespace mlir::mesh {
+
+namespace {
+
+/// Lower `mesh.process_multi_index` into an expression using
+/// `mesh.process_linear_index`, `mesh.cluster_shape` and
+/// `affine.delinearize_index`.
+///
+/// The linear index of the current device is delinearized over the full mesh
+/// shape; the components for the op's mesh axes are then selected. When the op
+/// has no explicit axes, all components are returned in mesh-axis order.
+struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
+  /// `symbolTableCollection` is held by reference and must outlive the
+  /// pattern. The remaining arguments are forwarded to `OpRewritePattern`.
+  template <typename... OpRewritePatternArgs>
+  ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
+                              OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern(
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+        symbolTableCollection(symbolTableCollection) {}
+
+  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    ClusterOp mesh =
+        symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+            op.getOperation(), op.getMeshAttr());
+    if (!mesh) {
+      return failure();
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+    Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
+    ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
+    auto delinearizeOp = builder.create<affine::AffineDelinearizeIndexOp>(
+        linearIndex, meshShape);
+    SmallVector<Value> completeMultiIndex = delinearizeOp.getMultiIndex();
+
+    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
+    SmallVector<Value> multiIndex;
+    if (opMeshAxes.empty()) {
+      // Empty axes means "all mesh axes", which is exactly the full result
+      // of the delinearization.
+      multiIndex = completeMultiIndex;
+    } else {
+      multiIndex.reserve(opMeshAxes.size());
+      llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
+                      [&completeMultiIndex](MeshAxis meshAxis) {
+                        return completeMultiIndex[meshAxis];
+                      });
+    }
+    // Replace the uses and erase the op so the rewrite is self-contained and
+    // does not depend on the driver removing the now-dead op.
+    rewriter.replaceOp(op, multiIndex);
+    return success();
+  }
+
+private:
+  SymbolTableCollection &symbolTableCollection;
+};
+
+} // namespace
+
+void processMultiIndexOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
+                                            patterns.getContext());
+}
+
+void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
+  registry.insert<affine::AffineDialect, mesh::MeshDialect>();
+}
+
+} // namespace mlir::mesh
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3e1b04da0dfda9..753ec3ca7d0479 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -128,9 +128,9 @@ func.func @cluster_shape_invalid_mesh_name() -> (index) {
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
-func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
+func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
   // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
+  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
   return %0#0, %0#1 : index, index
 }
 
@@ -138,9 +138,9 @@ func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
 
 mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
 
-func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
+func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
   // expected-error at +1 {{Mesh axes contains duplicate elements.}}
-  %0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
+  %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
@@ -148,9 +148,9 @@ func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
-func.func @process_index_wrong_number_of_results() -> (index, index) {
+func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
-  %0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
+  %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
   return %0#0, %0#1 : index, index
 }
 
@@ -158,17 +158,17 @@ func.func @process_index_wrong_number_of_results() -> (index, index) {
 
 mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
 
-func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
-  %0:2 = mesh.process_index on @mesh0 : index, index
+  %0:2 = mesh.process_multi_index on @mesh0 : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-func.func @process_index_invalid_mesh_name() -> (index) {
+func.func @process_multi_index_invalid_mesh_name() -> (index) {
   // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
+  %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
   return %0#0 : index
 }
 
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index a7c3b3dbab9c13..b0b4f9f8765b97 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -156,26 +156,26 @@ func.func @cluster_shape_empty_axes() -> (index, index, index) {
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
-// CHECK-LABEL: func @process_index
-func.func @process_index() -> (index, index) {
-  // CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
-  %0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+// CHECK-LABEL: func @process_multi_index
+func.func @process_multi_index() -> (index, index) {
+  // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
+  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
   return %0#0, %0#1 : index, index
 }
 
-// CHECK-LABEL: func @process_index_default_axes
-func.func @process_index_default_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
-  %0:3 = mesh.process_index on @mesh0 : index, index, index
+// CHECK-LABEL: func @process_multi_index_default_axes
+func.func @process_multi_index_default_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
-// CHECK-LABEL: func @process_index_empty_axes
-func.func @process_index_empty_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
-  %0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+// CHECK-LABEL: func @process_multi_index_empty_axes
+func.func @process_multi_index_empty_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
new file mode 100644
index 00000000000000..9602fb729c2681
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
+
+mesh.cluster @mesh2d(rank = 2)
+
+// CHECK-LABEL: func.func @multi_index_2d_mesh
+func.func @multi_index_2d_mesh() -> (index, index) {
+  // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
+  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  %0:2 = mesh.process_multi_index on @mesh2d : index, index
+  // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
+func.func @multi_index_2d_mesh_single_inner_axis() -> index {
+  // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
+  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
+  // CHECK: return %[[MULTI_IDX]]#0 : index
+  return %0 : index
+}
+
+// CHECK-LABEL: func.func @multi_index_2d_mesh_single_last_axis
+func.func @multi_index_2d_mesh_single_last_axis() -> index {
+  // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
+  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  %0 = mesh.process_multi_index on @mesh2d axes = [1] : index
+  // CHECK: return %[[MULTI_IDX]]#1 : index
+  return %0 : index
+}
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index c7088fe646d86f..786ea386df815a 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -21,7 +21,7 @@ func.func @split_replicated_tensor_axis(
 ) -> tensor<3x14xf32> {
   // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
-  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
   // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
   // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
   // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
@@ -43,7 +43,7 @@ func.func @split_replicated_tensor_axis_dynamic(
 ) -> tensor<?x3x?xf32> {
   // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
-  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
   // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
   // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
   // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index f14d282857a1e0..1e758cf899c1bd 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMeshTestSimplifications
+  TestProcessMultiIndexOpLowering.cpp
   TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
new file mode 100644
index 00000000000000..7acbf518970456
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
@@ -0,0 +1,60 @@
+//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies the `mesh.process_multi_index` lowering
+/// patterns to the operation it runs on.
+struct TestMultiIndexOpLoweringPass
+    : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mesh::MeshDialect>();
+    // Dialects of the ops created by the lowering.
+    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-process-multi-index-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of mesh.process_multi_index op.";
+  }
+};
+} // namespace
+
+void TestMultiIndexOpLoweringPass::runOnOperation() {
+  // Declared before `patterns` because the lowering pattern keeps a reference
+  // to it.
+  SymbolTableCollection symbolTableCollection;
+  RewritePatternSet patterns(&getContext());
+  mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+                                                    symbolTableCollection);
+  // Use signalPassFailure rather than `assert`, which is compiled out in
+  // release builds and would silently ignore a failure result.
+  if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                          std::move(patterns))))
+    signalPassFailure();
+}
+
+namespace mlir {
+namespace test {
+void registerTestMultiIndexOpLoweringPass() {
+  PassRegistration<TestMultiIndexOpLoweringPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 09ff66e07957af..5c6a72881ddf41 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,6 +120,7 @@ void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestMeshSimplificationsPass();
 void registerTestMeshReshardingSpmdizationPass();
+void registerTestMultiIndexOpLoweringPass();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -240,6 +241,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestMultiIndexOpLoweringPass();
   mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();



More information about the Mlir-commits mailing list