[Mlir-commits] [mlir] dd16cd7 - [mlir][gpu] Add a pattern for transforming gpu.global_id to thread + blockId * blockDim
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 25 13:25:40 PDT 2023
Author: Fabian Mora
Date: 2023-05-25T20:24:38Z
New Revision: dd16cd731dfb4746a351380edc848199cf9631e8
URL: https://github.com/llvm/llvm-project/commit/dd16cd731dfb4746a351380edc848199cf9631e8
DIFF: https://github.com/llvm/llvm-project/commit/dd16cd731dfb4746a351380edc848199cf9631e8.diff
LOG: [mlir][gpu] Add a pattern for transforming gpu.global_id to thread + blockId * blockDim
This patch implements a rewrite pattern for transforming gpu.global_id x
to gpu.thread_id + gpu.block_id * gpu.block_dim.
Reviewed By: makslevental
Differential Revision: https://reviews.llvm.org/D148978
Added:
mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
mlir/test/Dialect/GPU/globalId-rewrite.mlir
Modified:
mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
mlir/lib/Dialect/GPU/CMakeLists.txt
mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index a74db79bcefa2..5e2ff6d646ce7 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -53,12 +53,16 @@ std::unique_ptr<OperationPass<func::FuncOp>> createGpuAsyncRegionPass();
/// mapped to sequential loops.
std::unique_ptr<OperationPass<func::FuncOp>> createGpuMapParallelLoopsPass();
+/// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
+void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);
+
/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
void populateGpuAllReducePatterns(RewritePatternSet &patterns);
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
+ populateGpuGlobalIdPatterns(patterns);
}
namespace gpu {
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 844b9ac619484..2211e15a5d4b3 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -46,6 +46,7 @@ add_mlir_dialect_library(MLIRGPUDialect
add_mlir_dialect_library(MLIRGPUTransforms
Transforms/AllReduceLowering.cpp
Transforms/AsyncRegionRewriter.cpp
+ Transforms/GlobalIdRewriter.cpp
Transforms/KernelOutlining.cpp
Transforms/MemoryPromotion.cpp
Transforms/ParallelLoopMapper.cpp
@@ -75,6 +76,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
MLIRExecutionEngineUtils
MLIRGPUDialect
MLIRIR
+ MLIRIndexDialect
MLIRLLVMDialect
MLIRGPUToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
diff --git a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
new file mode 100644
index 0000000000000..0c730df73b519
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
@@ -0,0 +1,45 @@
+//===- GlobalIdRewriter.cpp - Implementation of GlobalId rewriting -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements in-dialect rewriting of the global_id op for archs
+// where global_id.x = threadId.x + blockId.x * blockDim.x
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrite pattern that expands `gpu.global_id <dim>` into
+/// `gpu.thread_id <dim> + gpu.block_id <dim> * gpu.block_dim <dim>`, using
+/// index dialect ops for the multiplication and addition.
+struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> {
+  using OpRewritePattern<gpu::GlobalIdOp>::OpRewritePattern;
+
+  /// Replaces `op` with the expanded computation for the dimension `op`
+  /// queries. The rewrite has no preconditions, so it always returns success.
+  LogicalResult matchAndRewrite(gpu::GlobalIdOp op,
+                                PatternRewriter &rewriter) const override {
+    // Every op created below uses the location and dimension of the
+    // gpu.global_id op being replaced.
+    Location loc = op.getLoc();
+    auto dim = op.getDimension();
+    auto blockId = rewriter.create<gpu::BlockIdOp>(loc, dim);
+    auto blockDim = rewriter.create<gpu::BlockDimOp>(loc, dim);
+    // Compute blockId * blockDim.
+    auto blockOffset = rewriter.create<index::MulOp>(loc, blockId, blockDim);
+    auto threadId = rewriter.create<gpu::ThreadIdOp>(loc, dim);
+    // Compute threadId + blockId * blockDim and substitute it for `op`.
+    rewriter.replaceOpWithNewOp<index::AddOp>(op, threadId, blockOffset);
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the gpu.global_id expansion pattern (GpuGlobalIdRewriter) to
+/// `patterns`, constructing it with the context owned by the pattern set.
+void mlir::populateGpuGlobalIdPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<GpuGlobalIdRewriter>(context);
+}
diff --git a/mlir/test/Dialect/GPU/globalId-rewrite.mlir b/mlir/test/Dialect/GPU/globalId-rewrite.mlir
new file mode 100644
index 0000000000000..9e02d69daa436
--- /dev/null
+++ b/mlir/test/Dialect/GPU/globalId-rewrite.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
+
+module {
+  // Checks that each gpu.global_id is expanded, in order, into
+  // block_id, block_dim, index.mul, thread_id and index.add for its dimension.
+  // CHECK-LABEL: func.func @globalId
+  // CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
+  func.func @globalId(%sz : index, %mem: memref<index, 1>) {
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+      // CHECK: %[[BIDX:.*]] = gpu.block_id x
+      // CHECK-NEXT: %[[BDIMX:.*]] = gpu.block_dim x
+      // CHECK-NEXT: %[[TMPX:.*]] = index.mul %[[BIDX]], %[[BDIMX]]
+      // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
+      // CHECK-NEXT: %[[GIDX:.*]] = index.add %[[TIDX]], %[[TMPX]]
+      %idx = gpu.global_id x
+      // CHECK: memref.store %[[GIDX]], %[[MEM]][] : memref<index, 1>
+      memref.store %idx, %mem[] : memref<index, 1>
+
+      // CHECK: %[[BIDY:.*]] = gpu.block_id y
+      // CHECK-NEXT: %[[BDIMY:.*]] = gpu.block_dim y
+      // CHECK-NEXT: %[[TMPY:.*]] = index.mul %[[BIDY]], %[[BDIMY]]
+      // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
+      // CHECK-NEXT: %[[GIDY:.*]] = index.add %[[TIDY]], %[[TMPY]]
+      %idy = gpu.global_id y
+      // CHECK: memref.store %[[GIDY]], %[[MEM]][] : memref<index, 1>
+      memref.store %idy, %mem[] : memref<index, 1>
+
+      // CHECK: %[[BIDZ:.*]] = gpu.block_id z
+      // CHECK-NEXT: %[[BDIMZ:.*]] = gpu.block_dim z
+      // CHECK-NEXT: %[[TMPZ:.*]] = index.mul %[[BIDZ]], %[[BDIMZ]]
+      // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
+      // CHECK-NEXT: %[[GIDZ:.*]] = index.add %[[TIDZ]], %[[TMPZ]]
+      %idz = gpu.global_id z
+      // CHECK: memref.store %[[GIDZ]], %[[MEM]][] : memref<index, 1>
+      memref.store %idz, %mem[] : memref<index, 1>
+      gpu.terminator
+    }
+    return
+  }
+}
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 909b62d4097c7..db65f3bccec52 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -25,7 +26,7 @@ struct TestGpuRewritePass
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGpuRewritePass)
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<arith::ArithDialect, func::FuncDialect,
+ registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect,
memref::MemRefDialect>();
}
StringRef getArgument() const final { return "test-gpu-rewrite"; }
More information about the Mlir-commits
mailing list