[Mlir-commits] [mlir] [mlir] Let GPU ID bounds work on any FunctionOpInterfaces (PR #95166)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Jun 11 12:43:13 PDT 2024


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/95166

This change removes the requirement that the known block or grid ID bounds be stored on a gpu.func; instead, they may be set on any function implementing FunctionOpInterface. This allows, for instance, non-kernel functions that live in a func.func, or downstream use cases that don't use gpu.func.
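
For instance, a plain func.func can now carry the bounds attributes and have gpu.block_id / gpu.thread_id ranges inferred from them; the sketch below is a condensed version of the test added in this patch:

    func.func @annotated_kernel()
        attributes {gpu.known_block_size = array<i32: 8, 12, 16>,
                    gpu.known_grid_size = array<i32: 20, 24, 28>} {
      %block_id_x = gpu.block_id x    // inferred range [0, 19]
      %thread_id_x = gpu.thread_id x  // inferred range [0, 7]
      return
    }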

>From de2679a2c84cd7be93fb6f81578cd227f9b1c040 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 11 Jun 2024 19:37:34 +0000
Subject: [PATCH] Let GPU ID bounds work on any FunctionOpInterfaces

This change removes the requirement that the known block or grid ID
bounds be stored on a gpu.func; instead, they may be set on any
function implementing FunctionOpInterface. This allows, for instance,
non-kernel functions that live in a func.func, or downstream use cases
that don't use gpu.func.
---
 .../GPUCommon/IndexIntrinsicsOpLowering.h     |  6 +---
 .../GPU/IR/InferIntRangeInterfaceImpls.cpp    | 22 +++++++++++--
 .../test/Dialect/GPU/int-range-interface.mlir | 33 +++++++++++++++++++
 3 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index d067c70a90ea4..0f74768207205 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -57,11 +57,7 @@ struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern<Op> {
       break;
     }
 
-    Operation *function;
-    if (auto gpuFunc = op->template getParentOfType<gpu::GPUFuncOp>())
-      function = gpuFunc;
-    if (auto llvmFunc = op->template getParentOfType<LLVM::LLVMFuncOp>())
-      function = llvmFunc;
+    Operation *function = op->template getParentOfType<FunctionOpInterface>();
     if (!boundsAttrName.empty() && function) {
       if (auto attr = function->template getAttrOfType<DenseI32ArrayAttr>(
               boundsAttrName)) {
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 69017efb9a0e6..152884e23b929 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -54,6 +55,17 @@ static Value valueByDim(KernelDim3 dims, Dimension dim) {
 
 static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); }
 
+static std::optional<uint32_t> getKnownLaunchAttr(FunctionOpInterface func,
+                                                  StringRef attrName,
+                                                  Dimension dim) {
+  auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
+  if (!bounds)
+    return std::nullopt;
+  if (bounds.size() <= static_cast<uint32_t>(dim))
+    return std::nullopt;
+  return bounds[static_cast<uint32_t>(dim)];
+}
+
 template <typename Op>
 static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
   Dimension dim = op.getDimension();
@@ -73,12 +85,16 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
       return value.getZExtValue();
   }
 
-  if (auto func = op->template getParentOfType<GPUFuncOp>()) {
+  if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
     switch (type) {
     case LaunchDims::Block:
-      return llvm::transformOptional(func.getKnownBlockSize(dim), zext);
+      return llvm::transformOptional(
+          getKnownLaunchAttr(func, GPUFuncOp::getKnownBlockSizeAttrName(), dim),
+          zext);
     case LaunchDims::Grid:
-      return llvm::transformOptional(func.getKnownGridSize(dim), zext);
+      return llvm::transformOptional(
+          getKnownLaunchAttr(func, GPUFuncOp::getKnownGridSizeAttrName(), dim),
+          zext);
     }
   }
   return std::nullopt;
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index a0917a2fdf110..a6c74fec6e824 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -215,3 +215,36 @@ module attributes {gpu.container_module} {
   }
 }
 
+// -----
+
+// CHECK-LABEL: func @annotated_kernel
+module {
+  func.func @annotated_kernel()
+    attributes {gpu.known_block_size = array<i32: 8, 12, 16>,
+        gpu.known_grid_size = array<i32: 20, 24, 28>} {
+
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %block_id_z = gpu.block_id z
+
+    // CHECK: test.reflect_bounds {smax = 19 : index, smin = 0 : index, umax = 19 : index, umin = 0 : index}
+    // CHECK: test.reflect_bounds {smax = 23 : index, smin = 0 : index, umax = 23 : index, umin = 0 : index}
+    // CHECK: test.reflect_bounds {smax = 27 : index, smin = 0 : index, umax = 27 : index, umin = 0 : index}
+    %block_id_x0 = test.reflect_bounds %block_id_x : index
+    %block_id_y0 = test.reflect_bounds %block_id_y : index
+    %block_id_z0 = test.reflect_bounds %block_id_z : index
+
+    %thread_id_x = gpu.thread_id x
+    %thread_id_y = gpu.thread_id y
+    %thread_id_z = gpu.thread_id z
+
+    // CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+    // CHECK: test.reflect_bounds {smax = 11 : index, smin = 0 : index, umax = 11 : index, umin = 0 : index}
+    // CHECK: test.reflect_bounds {smax = 15 : index, smin = 0 : index, umax = 15 : index, umin = 0 : index}
+    %thread_id_x0 = test.reflect_bounds %thread_id_x : index
+    %thread_id_y0 = test.reflect_bounds %thread_id_y : index
+    %thread_id_z0 = test.reflect_bounds %thread_id_z : index
+
+    return
+  }
+}


