[Mlir-commits] [mlir] be575c5 - Re-land D139865 "Add known_block_size and known_grid_size to gpu.func"
Krzysztof Drewniak
llvmlistbot@llvm.org
Mon Jan 2 08:39:09 PST 2023
Author: Krzysztof Drewniak
Date: 2023-01-02T16:39:00Z
New Revision: be575c5dfc55a2ebac463be97d863a6f2962926a
URL: https://github.com/llvm/llvm-project/commit/be575c5dfc55a2ebac463be97d863a6f2962926a
DIFF: https://github.com/llvm/llvm-project/commit/be575c5dfc55a2ebac463be97d863a6f2962926a.diff
LOG: Re-land D139865 "Add known_block_size and known_grid_size to gpu.func"
This should fix the MSVC warning that caused the previous revert.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D140766
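For illustration, here is a kernel carrying the new annotations, adapted from the `int-range-interface.mlir` test added below (the module and kernel names are illustrative):

```mlir
gpu.module @kernels {
  gpu.func @annotated_kernel() kernel
      attributes {gpu.known_block_size = array<i32: 8, 12, 16>,
                  gpu.known_grid_size = array<i32: 20, 24, 28>} {
    // With the annotation, gpu.thread_id x is known to lie in [0, 7].
    %tid_x = gpu.thread_id x
    gpu.return
  }
}
```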
Added:
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
mlir/test/Dialect/GPU/int-range-interface.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/outlining.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index baf9540c8b695..44423078ff924 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -205,6 +205,14 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
coordinate work items. Declarations of GPU functions, i.e. not having the
body region, are not supported.
+ A function may optionally be annotated with the block and/or grid sizes
+ it will be launched with, using the `gpu.known_block_size` and
+ `gpu.known_grid_size` attributes, respectively. If set, these attributes must
+ be arrays of three 32-bit integers giving the x, y, and z launch dimensions.
+ Launching a kernel with a block size or grid size other than the annotated
+ one, whether the annotations are on the kernel itself or on a function it
+ calls, is undefined behavior.
+
Syntax:
```
@@ -311,6 +319,36 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
return "workgroup_attributions";
}
+ static constexpr StringLiteral getKnownBlockSizeAttrName() {
+ return StringLiteral("gpu.known_block_size");
+ }
+
+ static constexpr StringLiteral getKnownGridSizeAttrName() {
+ return StringLiteral("gpu.known_grid_size");
+ }
+
+ /// Returns the block size this kernel will be launched with along
+ /// dimension `dim` if known. The value of gpu.thread_id dim will be strictly
+ /// less than this size.
+ Optional<uint32_t> getKnownBlockSize(gpu::Dimension dim) {
+ if (auto array =
+ (*this)->getAttrOfType<DenseI32ArrayAttr>(getKnownBlockSizeAttrName())) {
+ return array[static_cast<uint32_t>(dim)];
+ }
+ return std::nullopt;
+ }
+
+ /// Returns the grid size this kernel will be launched with along
+ /// dimension `dim` if known. The value of gpu.block_id dim will be strictly
+ /// less than this size.
+ Optional<uint32_t> getKnownGridSize(gpu::Dimension dim) {
+ if (auto array =
+ (*this)->getAttrOfType<DenseI32ArrayAttr>(getKnownGridSizeAttrName())) {
+ return array[static_cast<uint32_t>(dim)];
+ }
+ return std::nullopt;
+ }
+
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
@@ -329,6 +367,8 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
LogicalResult verifyBody();
}];
let hasCustomAssemblyFormat = 1;
+
+ let hasVerifier = 1;
}
def GPU_LaunchFuncOp : GPU_Op<"launch_func",
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index e1d92b9eac315..d687043c22f79 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -1057,6 +1058,27 @@ LogicalResult GPUFuncOp::verifyBody() {
return success();
}
+static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
+ StringRef attrName) {
+ auto maybeAttr = op->getAttr(attrName);
+ if (!maybeAttr)
+ return success();
+ auto array = maybeAttr.dyn_cast<DenseI32ArrayAttr>();
+ if (!array)
+ return op.emitOpError(attrName + " must be a dense i32 array");
+ if (array.size() != 3)
+ return op.emitOpError(attrName + " must contain exactly 3 elements");
+ return success();
+}
+
+LogicalResult GPUFuncOp::verify() {
+ if (failed(verifyKnownLaunchSizeAttr(*this, getKnownBlockSizeAttrName())))
+ return failure();
+ if (failed(verifyKnownLaunchSizeAttr(*this, getKnownGridSizeAttrName())))
+ return failure();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
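As a sketch of what the new verifier enforces, adapted from the `invalid.mlir` cases added below (the module and function names are illustrative):

```mlir
gpu.module @verifier_examples {
  // Rejected: gpu.known_block_size must be a dense i32 array.
  gpu.func @scalar_attr() kernel
      attributes {gpu.known_block_size = 32 : i32} {
    gpu.return
  }
  // Rejected: the array must contain exactly 3 elements (x, y, z).
  gpu.func @two_elems() kernel
      attributes {gpu.known_block_size = array<i32: 2, 1>} {
    gpu.return
  }
  // Accepted: three 32-bit integers.
  gpu.func @ok() kernel
      attributes {gpu.known_block_size = array<i32: 2, 1, 1>} {
    gpu.return
  }
}
```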
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 3df44a29296ba..aacd8d57d99bb 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -7,7 +7,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "llvm/ADT/STLForwardCompat.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
+#include <optional>
using namespace mlir;
using namespace mlir::gpu;
@@ -23,40 +28,108 @@ static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
APInt(width, umax));
}
+namespace {
+enum class LaunchDims : uint32_t { Block = 0, Grid = 1 };
+} // end namespace
+
+/// If the operation `op` is in a context that is annotated with maximum
+/// launch dimensions (a `gpu.launch` op with constant block or grid sizes, or
+/// a `gpu.func` annotated with `gpu.known_block_size`/`gpu.known_grid_size`),
+/// return the bound on the maximum size of the dimension that the op is
+/// querying. IDs will be one less than this bound.
+
+static Value valueByDim(KernelDim3 dims, Dimension dim) {
+ switch (dim) {
+ case Dimension::x:
+ return dims.x;
+ case Dimension::y:
+ return dims.y;
+ case Dimension::z:
+ return dims.z;
+ }
+ llvm_unreachable("All dimension enum cases handled above");
+}
+
+static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); }
+
+template <typename Op>
+static Optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
+ Dimension dim = op.getDimension();
+ if (auto launch = op->template getParentOfType<LaunchOp>()) {
+ KernelDim3 bounds;
+ switch (type) {
+ case LaunchDims::Block:
+ bounds = launch.getBlockSizeOperandValues();
+ break;
+ case LaunchDims::Grid:
+ bounds = launch.getGridSizeOperandValues();
+ break;
+ }
+ Value maybeBound = valueByDim(bounds, dim);
+ APInt value;
+ if (matchPattern(maybeBound, m_ConstantInt(&value)))
+ return value.getZExtValue();
+ }
+
+ if (auto func = op->template getParentOfType<GPUFuncOp>()) {
+ switch (type) {
+ case LaunchDims::Block:
+ return llvm::transformOptional(func.getKnownBlockSize(dim), zext);
+ case LaunchDims::Grid:
+ return llvm::transformOptional(func.getKnownGridSize(dim), zext);
+ }
+ }
+ return std::nullopt;
+}
+
void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(1, kMaxDim));
+ Optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Block);
+ if (knownVal)
+ setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
+ else
+ setResultRange(getResult(), getIndexRange(1, kMaxDim));
}
void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
+ uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
+ setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(1, kMaxDim));
+ Optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
+ if (knownVal)
+ setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
+ else
+ setResultRange(getResult(), getIndexRange(1, kMaxDim));
}
void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
+ uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
+ setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1));
+ setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
}
void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
+ setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
}
void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
+ uint64_t blockDimMax =
+ getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
+ uint64_t gridDimMax =
+ getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
setResultRange(getResult(),
- getIndexRange(0, std::numeric_limits<int64_t>::max()));
+ getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
}
void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
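To make the `GlobalIdOp` change concrete: the inferred upper bound is now `blockDimMax * gridDimMax - 1`, where each factor falls back to `kMaxDim` when unannotated. A sketch, reusing the annotations from the test below:

```mlir
// Inside a kernel with gpu.known_block_size = array<i32: 8, 12, 16> and
// gpu.known_grid_size = array<i32: 20, 24, 28>:
%gid_x = gpu.global_id x
// Inferred range: [0, 8 * 20 - 1] = [0, 159], matching the reflect_bounds
// checks in int-range-interface.mlir below.
```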
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index fadae79eff85b..e8883ea7c8eb7 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -22,10 +22,12 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
+#include <limits>
namespace mlir {
#define GEN_PASS_DEF_GPULAUNCHSINKINDEXCOMPUTATIONS
@@ -147,8 +149,27 @@ LogicalResult mlir::sinkOperationsIntoLaunchOp(
return success();
}
+/// Return the provided KernelDim3 as an array of i32 constants if possible.
+static DenseI32ArrayAttr maybeConstantDimsAttr(gpu::KernelDim3 dims) {
+ SmallVector<int32_t, 3> constants;
+ MLIRContext *ctx = dims.x.getContext();
+ for (Value v : {dims.x, dims.y, dims.z}) {
+ APInt constValue;
+ if (!matchPattern(v, m_ConstantInt(&constValue)))
+ return nullptr;
+ // If someone requested a block or grid dimension that doesn't fit in an
+ // i32, don't record bounds, as doing so is likely to cause more confusing
+ // behavior.
+ if (constValue.ugt(std::numeric_limits<uint32_t>::max()))
+ return nullptr;
+ constants.push_back(
+ constValue.getLimitedValue(std::numeric_limits<uint32_t>::max()));
+ }
+ return DenseI32ArrayAttr::get(ctx, constants);
+}
+
/// Outline the `gpu.launch` operation body into a kernel function. Replace
/// `gpu.terminator` operations by `gpu.return` in the generated function.
+/// Set block and grid size bounds if known.
static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
StringRef kernelFnName,
SetVector<Value> &operands) {
@@ -173,6 +194,19 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFnName, type);
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
+
+ // If we can infer bounds on the grid and/or block sizes from the arguments
+ // to the launch op, propagate them to the generated kernel. This is safe
+ // because multiple launches with the same body are not deduplicated.
+ if (auto blockBounds =
+ maybeConstantDimsAttr(launchOp.getBlockSizeOperandValues()))
+ outlinedFunc->setAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName(),
+ blockBounds);
+ if (auto gridBounds =
+ maybeConstantDimsAttr(launchOp.getGridSizeOperandValues()))
+ outlinedFunc->setAttr(gpu::GPUFuncOp::getKnownGridSizeAttrName(),
+ gridBounds);
+
BlockAndValueMapping map;
// Map the arguments corresponding to the launch parameters like blockIdx,
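A minimal sketch of the new outlining behavior (constants and names are illustrative): a `gpu.launch` with constant dimensions now produces an annotated kernel.

```mlir
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
           threads(%tx, %ty, %tz) in (%sx = %c32, %sy = %c1, %sz = %c1) {
  gpu.terminator
}
// After -gpu-kernel-outlining, the generated gpu.func is annotated with
//   gpu.known_grid_size = array<i32: 1, 1, 1>
//   gpu.known_block_size = array<i32: 32, 1, 1>
// Non-constant launch dimensions, as in the @non_constant_launches test
// below, leave the attributes unset.
```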
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 2c5af0886e9f5..02aec9dc0476f 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @launch_func
func.func @launch_func(%arg0 : index) {
@@ -41,12 +41,18 @@ func.func @launch_func(%arg0 : index) {
%thread_id_y0 = test.reflect_bounds %thread_id_y
%thread_id_z0 = test.reflect_bounds %thread_id_z
+ // The launch bounds are not constant here, so nothing can be inferred.
+ // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
+ %thread_id_op = gpu.thread_id y
+ %thread_id_op0 = test.reflect_bounds %thread_id_op
gpu.terminator
}
func.return
}
+// -----
+
// CHECK-LABEL: func @kernel
module attributes {gpu.container_module} {
gpu.module @gpu_module {
@@ -100,9 +106,9 @@ module attributes {gpu.container_module} {
%global_id_y = gpu.global_id y
%global_id_z = gpu.global_id z
- // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index}
- // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index}
- // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
%global_id_x0 = test.reflect_bounds %global_id_x
%global_id_y0 = test.reflect_bounds %global_id_y
%global_id_z0 = test.reflect_bounds %global_id_z
@@ -126,3 +132,86 @@ module attributes {gpu.container_module} {
}
}
+// -----
+
+// CHECK-LABEL: func @annotated_kernel
+module attributes {gpu.container_module} {
+ gpu.module @gpu_module {
+ gpu.func @annotated_kernel() kernel
+ attributes {gpu.known_block_size = array<i32: 8, 12, 16>,
+ gpu.known_grid_size = array<i32: 20, 24, 28>} {
+
+ %grid_dim_x = gpu.grid_dim x
+ %grid_dim_y = gpu.grid_dim y
+ %grid_dim_z = gpu.grid_dim z
+
+ // CHECK: test.reflect_bounds {smax = 20 : index, smin = 20 : index, umax = 20 : index, umin = 20 : index}
+ // CHECK: test.reflect_bounds {smax = 24 : index, smin = 24 : index, umax = 24 : index, umin = 24 : index}
+ // CHECK: test.reflect_bounds {smax = 28 : index, smin = 28 : index, umax = 28 : index, umin = 28 : index}
+ %grid_dim_x0 = test.reflect_bounds %grid_dim_x
+ %grid_dim_y0 = test.reflect_bounds %grid_dim_y
+ %grid_dim_z0 = test.reflect_bounds %grid_dim_z
+
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %block_id_z = gpu.block_id z
+
+ // CHECK: test.reflect_bounds {smax = 19 : index, smin = 0 : index, umax = 19 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 23 : index, smin = 0 : index, umax = 23 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 27 : index, smin = 0 : index, umax = 27 : index, umin = 0 : index}
+ %block_id_x0 = test.reflect_bounds %block_id_x
+ %block_id_y0 = test.reflect_bounds %block_id_y
+ %block_id_z0 = test.reflect_bounds %block_id_z
+
+ %block_dim_x = gpu.block_dim x
+ %block_dim_y = gpu.block_dim y
+ %block_dim_z = gpu.block_dim z
+
+ // CHECK: test.reflect_bounds {smax = 8 : index, smin = 8 : index, umax = 8 : index, umin = 8 : index}
+ // CHECK: test.reflect_bounds {smax = 12 : index, smin = 12 : index, umax = 12 : index, umin = 12 : index}
+ // CHECK: test.reflect_bounds {smax = 16 : index, smin = 16 : index, umax = 16 : index, umin = 16 : index}
+ %block_dim_x0 = test.reflect_bounds %block_dim_x
+ %block_dim_y0 = test.reflect_bounds %block_dim_y
+ %block_dim_z0 = test.reflect_bounds %block_dim_z
+
+ %thread_id_x = gpu.thread_id x
+ %thread_id_y = gpu.thread_id y
+ %thread_id_z = gpu.thread_id z
+
+ // CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 11 : index, smin = 0 : index, umax = 11 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 15 : index, smin = 0 : index, umax = 15 : index, umin = 0 : index}
+ %thread_id_x0 = test.reflect_bounds %thread_id_x
+ %thread_id_y0 = test.reflect_bounds %thread_id_y
+ %thread_id_z0 = test.reflect_bounds %thread_id_z
+
+ %global_id_x = gpu.global_id x
+ %global_id_y = gpu.global_id y
+ %global_id_z = gpu.global_id z
+
+ // CHECK: test.reflect_bounds {smax = 159 : index, smin = 0 : index, umax = 159 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 287 : index, smin = 0 : index, umax = 287 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 447 : index, smin = 0 : index, umax = 447 : index, umin = 0 : index}
+ %global_id_x0 = test.reflect_bounds %global_id_x
+ %global_id_y0 = test.reflect_bounds %global_id_y
+ %global_id_z0 = test.reflect_bounds %global_id_z
+
+ %subgroup_size = gpu.subgroup_size : index
+ %lane_id = gpu.lane_id
+ %num_subgroups = gpu.num_subgroups : index
+ %subgroup_id = gpu.subgroup_id : index
+
+ // CHECK: test.reflect_bounds {smax = 128 : index, smin = 1 : index, umax = 128 : index, umin = 1 : index}
+ // CHECK: test.reflect_bounds {smax = 127 : index, smin = 0 : index, umax = 127 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
+ // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
+ %subgroup_size0 = test.reflect_bounds %subgroup_size
+ %lane_id0 = test.reflect_bounds %lane_id
+ %num_subgroups0 = test.reflect_bounds %num_subgroups
+ %subgroup_id0 = test.reflect_bounds %subgroup_id
+
+ gpu.return
+ }
+ }
+}
+
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 7a11acbc2d239..76a14d353bc4f 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -599,3 +599,25 @@ func.func @alloc() {
%1 = gpu.alloc(%0) : memref<2x?x?xf32, 1>
return
}
+
+// -----
+
+module attributes {gpu.container_module} {
+ gpu.module @kernel {
+ // expected-error @+1 {{'gpu.func' op gpu.known_block_size must be a dense i32 array}}
+ gpu.func @kernel() kernel attributes {gpu.known_block_size = 32 : i32} {
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {gpu.container_module} {
+ gpu.module @kernel {
+ // expected-error @+1 {{'gpu.func' op gpu.known_block_size must contain exactly 3 elements}}
+ gpu.func @kernel() kernel attributes {gpu.known_block_size = array<i32: 2, 1>} {
+ gpu.return
+ }
+ }
+}
diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
index 5191dcf8fffb8..422e0c154dd47 100644
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -41,6 +41,8 @@ func.func @launch() {
// CHECK-LABEL: gpu.module @launch_kernel
// CHECK-NEXT: gpu.func @launch_kernel
// CHECK-SAME: (%[[KERNEL_ARG0:.*]]: f32, %[[KERNEL_ARG1:.*]]: memref<?xf32, 1>)
+// CHECK-SAME: gpu.known_block_size = array<i32: 20, 24, 28>
+// CHECK-SAME: gpu.known_grid_size = array<i32: 8, 12, 16>
// CHECK-NEXT: %[[BID:.*]] = gpu.block_id x
// CHECK-NEXT: = gpu.block_id y
// CHECK-NEXT: = gpu.block_id z
@@ -291,3 +293,20 @@ func.func @recursive_device_function() {
// CHECK: func @device_function()
// CHECK: func @recursive_device_function()
// CHECK-NOT: func @device_function
+
+// -----
+
+// CHECK-LABEL: @non_constant_launches
+func.func @non_constant_launches(%arg0 : index) {
+ // CHECK-NOT: gpu.known_block_size
+ // CHECK-NOT: gpu.known_grid_size
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %arg0, %grid_y = %arg0,
+ %grid_z = %arg0)
+ threads(%tx, %ty, %tz) in (%block_x = %arg0, %block_y = %arg0,
+ %block_z = %arg0) {
+ gpu.terminator
+ }
+ return
+}
+
+// CHECK-DL-LABEL: gpu.module @non_constant_launches_kernel attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>}