[Mlir-commits] [mlir] [mlir][mesh] Add lowering of process multi-index op (PR #77490)
Boian Petkantchin
llvmlistbot at llvm.org
Tue Jan 9 08:05:53 PST 2024
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/77490
>From 786f5a7e3bfed2e8e10e6dc1a451b453f84b189d Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 5 Jan 2024 17:48:36 -0800
Subject: [PATCH] [mlir][mesh] Add lowering of process multi-index op
* Rename mesh.process_index -> mesh.process_multi_index.
* Add mesh.process_linear_index op.
* Add lowering of mesh.process_multi_index into an expression using
mesh.process_linear_index, mesh.cluster_shape and affine.delinearize_index.
This is useful when lowering mesh ops further, toward runtimes that may expose
only the linear index of a device/process.
For example, in MPI a process is identified by its rank (linear index) within
a communicator.
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 32 ++++++-
.../mlir/Dialect/Mesh/Transforms/Transforms.h | 26 ++++++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 33 ++++++--
.../Dialect/Mesh/Transforms/CMakeLists.txt | 2 +
.../Mesh/Transforms/Simplifications.cpp | 2 +-
.../Dialect/Mesh/Transforms/Spmdization.cpp | 4 +-
.../Dialect/Mesh/Transforms/Transforms.cpp | 84 +++++++++++++++++++
mlir/test/Dialect/Mesh/invalid.mlir | 20 ++---
mlir/test/Dialect/Mesh/ops.mlir | 24 +++---
.../Mesh/process-multi-index-op-lowering.mlir | 23 +++++
.../Dialect/Mesh/resharding-spmdization.mlir | 4 +-
mlir/test/lib/Dialect/Mesh/CMakeLists.txt | 1 +
.../Mesh/TestProcessMultiIndexOpLowering.cpp | 55 ++++++++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
14 files changed, 275 insertions(+), 37 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
create mode 100644 mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
create mode 100644 mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
create mode 100644 mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f459077ea12022..a9068562f5c903 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,7 +96,8 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}
-def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
+ Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Get the shape of the cluster.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
@@ -209,11 +210,15 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
-def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
- let summary = "Get the index of current device along specified mesh axis.";
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>
+]> {
+ let summary = "Get the multi index of current device along specified mesh axes.";
let description = [{
It is used in the SPMD format of IR.
The `axes` mush be non-negative and less than the total number of mesh axes.
+ If the axes are empty then get the index along all axes.
}];
let arguments = (ins
FlatSymbolRefAttr:$mesh,
@@ -232,6 +237,27 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
];
}
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+]> {
+  let summary = "Get the linear index of the current device.";
+  let description = [{
+    Example:
+    ```
+    %idx = mesh.process_linear_index on @mesh : index
+    ```
+    If `@mesh` has shape `(10, 20, 30)`, a device with multi
+    index `(1, 2, 3)` will have row-major linear index `3 + 30*2 + 20*30*1`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh);
+  let results = (outs Index:$result);
+  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
+  ];
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
new file mode 100644
index 00000000000000..10a965daac71b9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -0,0 +1,32 @@
+//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+
+namespace mlir {
+class RewritePatternSet;
+class SymbolTableCollection;
+class DialectRegistry;
+namespace mesh {
+
+/// Add to `patterns` the rewrite that lowers `mesh.process_multi_index` into
+/// `mesh.process_linear_index`, `mesh.cluster_shape` and
+/// `affine.delinearize_index`. Mesh symbols are looked up through
+/// `symbolTableCollection`, which is held by reference and must outlive
+/// `patterns`.
+void processMultiIndexOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+
+/// Register the dialects whose ops the above lowering creates (affine, mesh).
+void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 6667d409df8b78..9b110c462915e7 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -250,7 +250,8 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
ClusterOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
- mesh.getSymName(), MeshAxesAttr());
+ mesh.getSymName(),
+ MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
}
void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
@@ -325,11 +326,11 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
}
//===----------------------------------------------------------------------===//
-// mesh.process_index op
+// mesh.process_multi_index op
//===----------------------------------------------------------------------===//
LogicalResult
-ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
@@ -348,20 +349,40 @@ ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  return success();
}
-void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                           ClusterOp mesh) {
+void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                                ClusterOp mesh) {
  build(odsBuilder, odsState,
        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
-        mesh.getSymName(), MeshAxesAttr());
+        // Explicit empty axes attribute (meaning "all axes") instead of a null
+        // one, consistent with ClusterShapeOp::build(ClusterOp).
+        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), {}));
}
-void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                                StringRef mesh, ArrayRef<MeshAxis> axes) {
  build(odsBuilder, odsState,
        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
        MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
+//===----------------------------------------------------------------------===//
+// mesh.process_linear_index op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // The op takes no mesh axes, so the only check is that its mesh resolves.
+  return success(
+      succeeded(::getMesh(getOperation(), getMeshAttr(), symbolTable)));
+}
+
+void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
+                                 OperationState &odsState, ClusterOp mesh) {
+  // Delegate to the generated builder, referencing the mesh by symbol name.
+  StringRef meshSymbolName = mesh.getSymName();
+  build(odsBuilder, odsState, meshSymbolName);
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index 7a70c047ec9dce..dccb75848c94f0 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
Simplifications.cpp
ShardingPropagation.cpp
Spmdization.cpp
+ Transforms.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
MLIRShardingInterface
LINK_LIBS PUBLIC
+ MLIRAffineDialect
MLIRArithDialect
MLIRControlFlowDialect
MLIRFuncDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 643bd7b8e77c93..c0f081ff8ceff4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -1,4 +1,4 @@
-//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
+//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 37b86535959652..0e83c024fc08f8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -206,8 +206,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Value processIndexAlongAxis =
builder
- .create<ProcessIndexOp>(mesh.getSymName(),
- SmallVector<MeshAxis>({splitMeshAxis}))
+ .create<ProcessMultiIndexOp>(mesh.getSymName(),
+ SmallVector<MeshAxis>({splitMeshAxis}))
.getResult()[0];
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
new file mode 100644
index 00000000000000..c27e173d877d69
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -0,0 +1,102 @@
+//===- Transforms.cpp ---------------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <iterator>
+#include <numeric>
+#include <utility>
+
+namespace mlir::mesh {
+
+namespace {
+
+/// Lower `mesh.process_multi_index` into an expression using
+/// `mesh.process_linear_index`, `mesh.cluster_shape` and
+/// `affine.delinearize_index`.
+///
+/// The linear index of the current process is delinearized over the full
+/// mesh shape, and the components for the requested mesh axes are then
+/// selected. An empty axis list selects all mesh axes, in order.
+///
+/// Mesh symbols are resolved through a `SymbolTableCollection` that is held
+/// by reference, so it must outlive this pattern.
+struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
+  template <typename... OpRewritePatternArgs>
+  ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
+                              OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern(
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+        symbolTableCollection(symbolTableCollection) {}
+
+  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    ClusterOp mesh =
+        symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+            op.getOperation(), op.getMeshAttr());
+    if (!mesh) {
+      return failure();
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+    Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
+    ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
+    SmallVector<Value> completeMultiIndex =
+        builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
+            .getMultiIndex();
+
+    // An empty axis list stands for all mesh axes, i.e. [0, rank).
+    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
+    SmallVector<MeshAxis> opAxesIota;
+    if (opMeshAxes.empty()) {
+      opAxesIota.resize(mesh.getRank());
+      std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
+      opMeshAxes = opAxesIota;
+    }
+
+    // Select the delinearized components of the requested axes, in order.
+    SmallVector<Value> multiIndex;
+    multiIndex.reserve(opMeshAxes.size());
+    llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
+                    [&completeMultiIndex](MeshAxis meshAxis) {
+                      return completeMultiIndex[meshAxis];
+                    });
+
+    // Replace (and erase) the op instead of only rewiring its uses, so the
+    // rewrite does not depend on the driver removing the now-dead op.
+    rewriter.replaceOp(op, multiIndex);
+    return success();
+  }
+
+private:
+  SymbolTableCollection &symbolTableCollection;
+};
+
+} // namespace
+
+void processMultiIndexOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
+                                            patterns.getContext());
+}
+
+void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
+  registry.insert<affine::AffineDialect, mesh::MeshDialect>();
+}
+
+} // namespace mlir::mesh
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3e1b04da0dfda9..753ec3ca7d0479 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -128,9 +128,9 @@ func.func @cluster_shape_invalid_mesh_name() -> (index) {
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
-func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
+func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
+  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
  return %0#0, %0#1 : index, index
}
@@ -138,9 +138,9 @@ func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
-func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
+func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
+  %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}
@@ -148,9 +148,9 @@ func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
-func.func @process_index_wrong_number_of_results() -> (index, index) {
+func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
-  %0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
+  %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
  return %0#0, %0#1 : index, index
}
@@ -158,17 +158,25 @@ func.func @process_index_wrong_number_of_results() -> (index, index) {
mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
-func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
-  %0:2 = mesh.process_index on @mesh0 : index, index
+  %0:2 = mesh.process_multi_index on @mesh0 : index, index
  return %0#0, %0#1 : index, index
}
+
+// -----
+
+func.func @process_linear_index_invalid_mesh_name() -> (index) {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
+  return %0 : index
+}
// -----
-func.func @process_index_invalid_mesh_name() -> (index) {
+func.func @process_multi_index_invalid_mesh_name() -> (index) {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
+  %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
  return %0#0 : index
}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index a7c3b3dbab9c13..b0b4f9f8765b97 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -156,26 +156,34 @@ func.func @cluster_shape_empty_axes() -> (index, index, index) {
  return %0#0, %0#1, %0#2 : index, index, index
}
-// CHECK-LABEL: func @process_index
-func.func @process_index() -> (index, index) {
-  // CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
-  %0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+// CHECK-LABEL: func @process_multi_index
+func.func @process_multi_index() -> (index, index) {
+  // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
+  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
  return %0#0, %0#1 : index, index
}
-// CHECK-LABEL: func @process_index_default_axes
-func.func @process_index_default_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
-  %0:3 = mesh.process_index on @mesh0 : index, index, index
+// CHECK-LABEL: func @process_multi_index_default_axes
+func.func @process_multi_index_default_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}
+
+// CHECK-LABEL: func @process_linear_index
+func.func @process_linear_index() -> index {
+  // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
+  %0 = mesh.process_linear_index on @mesh0 : index
+  // CHECK: return %[[RES]] : index
+  return %0 : index
+}
-// CHECK-LABEL: func @process_index_empty_axes
-func.func @process_index_empty_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
-  %0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+// CHECK-LABEL: func @process_multi_index_empty_axes
+func.func @process_multi_index_empty_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %0#0, %0#1, %0#2 : index, index, index
}
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
new file mode 100644
index 00000000000000..9602fb729c2681
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
+
+mesh.cluster @mesh2d(rank = 2)
+
+// CHECK-LABEL: func.func @multi_index_2d_mesh
+func.func @multi_index_2d_mesh() -> (index, index) {
+  // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
+  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  %0:2 = mesh.process_multi_index on @mesh2d : index, index
+  // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
+func.func @multi_index_2d_mesh_single_inner_axis() -> index {
+  // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
+  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
+  // CHECK: return %[[MULTI_IDX]]#0 : index
+  return %0 : index
+}
+
+// CHECK-LABEL: func.func @multi_index_2d_mesh_single_axis_1
+func.func @multi_index_2d_mesh_single_axis_1() -> index {
+  // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
+  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  %0 = mesh.process_multi_index on @mesh2d axes = [1] : index
+  // CHECK: return %[[MULTI_IDX]]#1 : index
+  return %0 : index
+}
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index c7088fe646d86f..786ea386df815a 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -21,7 +21,7 @@ func.func @split_replicated_tensor_axis(
) -> tensor<3x14xf32> {
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
- // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
+ // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
// CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
// CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
@@ -43,7 +43,7 @@ func.func @split_replicated_tensor_axis_dynamic(
) -> tensor<?x3x?xf32> {
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
- // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
+ // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
// CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index f14d282857a1e0..1e758cf899c1bd 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTestSimplifications
+ TestProcessMultiIndexOpLowering.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp
diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
new file mode 100644
index 00000000000000..7acbf518970456
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
@@ -0,0 +1,61 @@
+//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies the `mesh.process_multi_index` lowering
+/// patterns to the operation it runs on.
+struct TestMultiIndexOpLoweringPass
+    : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mesh::MeshDialect>();
+    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-process-multi-index-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of mesh.process_multi_index op.";
+  }
+};
+} // namespace
+
+void TestMultiIndexOpLoweringPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  SymbolTableCollection symbolTableCollection;
+  mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+                                                    symbolTableCollection);
+  LogicalResult status =
+      applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  // Report a failed (e.g. non-converging) rewrite through the pass manager
+  // instead of an assert, which is compiled out in release builds and would
+  // leave `status` unused there.
+  if (failed(status))
+    signalPassFailure();
+}
+
+namespace mlir {
+namespace test {
+void registerTestMultiIndexOpLoweringPass() {
+  PassRegistration<TestMultiIndexOpLoweringPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 09ff66e07957af..5c6a72881ddf41 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,6 +120,7 @@ void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
+void registerTestMultiIndexOpLoweringPass();
void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
@@ -240,6 +241,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
+ mlir::test::registerTestMultiIndexOpLoweringPass();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
More information about the Mlir-commits
mailing list