[Mlir-commits] [mlir] [mlir][GPU] Implement ValueBoundsOpInterface for GPU ID operations (PR #122190)
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Jan 9 09:36:28 PST 2025
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/122190
>From 464dddffb6e225d641491402310bc3f5f7d8729d Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 8 Jan 2025 23:24:07 +0000
Subject: [PATCH 1/2] [mlir][GPU] Implement ValueBoundsOpInterface for GPU ID
operations
The GPU ID operations already implement InferIntRangeInterface, which
gives constant lower and upper bounds on those IDs when appropriate
metadata is present on the operations or in the surrounding context.
This commit uses that existing code to implement the
ValueBoundsOpInterface, which is used when analyzing affine
operations (unlike the integer range interface, which is used for
arithmetic optimization).
It also implements the interface for gpu.launch, where we can use it
to express the constraint that block/grid sizes are equal to their
value from outside the launch op and that the corresponding IDs are
bounded above by that size.
As a consequence, the test pass for this inference is updated to work
on a FunctionOpInterface and not a func.func, creating minor churn in
other tests.
---
.../GPU/IR/ValueBoundsOpInterfaceImpl.h | 19 +++
mlir/include/mlir/InitAllDialects.h | 2 +
mlir/lib/Dialect/GPU/CMakeLists.txt | 3 +-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 5 +
.../GPU/IR/ValueBoundsOpInterfaceImpl.cpp | 114 +++++++++++++
.../value-bounds-op-interface-impl.mlir | 2 +-
.../Affine/value-bounds-reification.mlir | 4 +-
.../Arith/value-bounds-op-interface-impl.mlir | 4 +-
.../GPU/value-bounds-op-interface-impl.mlir | 150 ++++++++++++++++++
.../value-bounds-op-interface-impl.mlir | 2 +-
.../value-bounds-op-interface-impl.mlir | 2 +-
.../SCF/value-bounds-op-interface-impl.mlir | 2 +-
.../value-bounds-op-interface-impl.mlir | 2 +-
.../Dialect/Vector/test-scalable-bounds.mlir | 2 +-
.../value-bounds-op-interface-impl.mlir | 2 +-
.../Dialect/Affine/TestReifyValueBounds.cpp | 8 +-
16 files changed, 308 insertions(+), 15 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h
create mode 100644 mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp
create mode 100644 mlir/test/Dialect/GPU/value-bounds-op-interface-impl.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
index 00000000000000..9a4e159ef76c83
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,19 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_GPU_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+namespace gpu {
+/// Attaches ValueBoundsOpInterface models to the GPU ID ops and gpu.launch.
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace gpu
+} // namespace mlir
+#endif // MLIR_DIALECT_GPU_IR_VALUEBOUNDSOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 7fd0432ddce1bb..c102f811cce4b1 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -37,6 +37,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
@@ -164,6 +165,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
cf::registerBufferizableOpInterfaceExternalModels(registry);
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerValueBoundsOpInterfaceExternalModels(registry);
LLVM::registerInlinerInterface(registry);
linalg::registerAllDialectInterfaceImplementations(registry);
linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 1026e9b509332a..013311ec027dae 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRGPUDialect
IR/GPUDialect.cpp
IR/InferIntRangeInterfaceImpls.cpp
+ IR/ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -40,7 +41,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupReduceLowering.cpp
-
+
OBJECT
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8e36638d6e5453..49209229259a73 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -29,6 +29,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -217,6 +218,10 @@ void GPUDialect::initialize() {
addInterfaces<GPUInlinerInterface>();
declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
TerminatorOp>();
+ declarePromisedInterfaces<
+ ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
+ ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
+ SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}
static std::string getSparseHandleKeyword(SparseHandleKind kind) {
diff --git a/mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..cfca5cb04b1d2e
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,114 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace {
+/// Implement ValueBoundsOpInterface (which only works on index-typed values,
+/// gathers a set of constraint expressions, and is used for affine analyses)
+/// in terms of InferIntRangeInterface (which works on arbitrary integer
+/// types, creates [min, max] ranges, and is used for arithmetic
+/// simplification).
+template <typename Op>
+struct GpuIdOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<GpuIdOpInterface<Op>, Op> {
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto inferrable = cast<InferIntRangeInterface>(op);
+ assert(value == op->getResult(0) &&
+ "inferring for value that isn't the GPU op's result");
+ auto translateConstraint = [&](Value v, const ConstantIntRanges &range) {
+ assert(v == value &&
+ "GPU ID op inferring values for something that's not its result");
+ cstr.bound(v) >= range.smin().getSExtValue();
+ cstr.bound(v) <= range.smax().getSExtValue();
+ };
+ // ID ops take no operands, so there are no argument ranges to pass in.
+ inferrable.inferResultRanges({}, translateConstraint);
+ }
+};
+
+struct GpuLaunchOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<GpuLaunchOpInterface,
+ LaunchOp> {
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto launchOp = cast<LaunchOp>(op);
+ // Work out whether `value` is one of the launch body's ID or size args.
+ Value sizeArg = nullptr;
+ bool isSize = false;
+ KernelDim3 gridSizeArgs = launchOp.getGridSizeOperandValues();
+ KernelDim3 blockSizeArgs = launchOp.getBlockSizeOperandValues();
+ // If `value` is one of `bodyArgs`, record the matching outer operand.
+ auto match = [&](KernelDim3 bodyArgs, KernelDim3 externalArgs,
+ bool areSizeArgs) {
+ if (value == bodyArgs.x) {
+ sizeArg = externalArgs.x;
+ isSize = areSizeArgs;
+ }
+ if (value == bodyArgs.y) {
+ sizeArg = externalArgs.y;
+ isSize = areSizeArgs;
+ }
+ if (value == bodyArgs.z) {
+ sizeArg = externalArgs.z;
+ isSize = areSizeArgs;
+ }
+ };
+ match(launchOp.getThreadIds(), blockSizeArgs, false);
+ match(launchOp.getBlockSize(), blockSizeArgs, true);
+ match(launchOp.getBlockIds(), gridSizeArgs, false);
+ match(launchOp.getGridSize(), gridSizeArgs, true);
+ if (launchOp.hasClusterSize()) {
+ KernelDim3 clusterSizeArgs = *launchOp.getClusterSizeOperandValues();
+ match(*launchOp.getClusterIds(), clusterSizeArgs, false);
+ match(*launchOp.getClusterSize(), clusterSizeArgs, true);
+ }
+ // Sizes equal their outer operand and are >= 1; IDs lie in [0, size).
+ if (!sizeArg)
+ return;
+ if (isSize) {
+ cstr.bound(value) == cstr.getExpr(sizeArg);
+ cstr.bound(value) >= 1;
+ } else {
+ cstr.bound(value) < cstr.getExpr(sizeArg);
+ cstr.bound(value) >= 0;
+ }
+ }
+};
+} // namespace
+
+void mlir::gpu::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, GPUDialect *dialect) {
+#define REGISTER(X) X::attachInterface<GpuIdOpInterface<X>>(*ctx);
+    REGISTER(ClusterDimOp)
+    REGISTER(ClusterDimBlocksOp)
+    REGISTER(ClusterIdOp)
+    REGISTER(ClusterBlockIdOp)
+    REGISTER(BlockDimOp)
+    REGISTER(BlockIdOp)
+    REGISTER(GridDimOp)
+    REGISTER(ThreadIdOp)
+    REGISTER(LaneIdOp)
+    REGISTER(SubgroupIdOp)
+    REGISTER(GlobalIdOp)
+    REGISTER(NumSubgroupsOp)
+    REGISTER(SubgroupSizeOp)
+#undef REGISTER
+    // gpu.launch gets its own model, which bounds its body's block arguments.
+    LaunchOp::attachInterface<GpuLaunchOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 5354eb38d7b039..a4310b91a37b3d 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -verify-diagnostics \
// RUN: -split-input-file | FileCheck %s
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 75622f59af83be..817614be505332 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds{reify-to-func-args}))' \
// RUN: -verify-diagnostics -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args use-arith-ops" \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds{reify-to-func-args use-arith-ops}))' \
// RUN: -verify-diagnostics -split-input-file | FileCheck %s --check-prefix=CHECK-ARITH
// CHECK-LABEL: func @reify_through_chain(
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
index 8fb3ba1a1eccef..a2653d4750ec8b 100644
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -verify-diagnostics \
// RUN: -verify-diagnostics -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-affine-reify-value-bounds="use-arith-ops" \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds{use-arith-ops}))' \
// RUN: -verify-diagnostics -split-input-file | \
// RUN: FileCheck %s --check-prefix=CHECK-ARITH
diff --git a/mlir/test/Dialect/GPU/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/GPU/value-bounds-op-interface-impl.mlir
new file mode 100644
index 00000000000000..719d8a94c982ba
--- /dev/null
+++ b/mlir/test/Dialect/GPU/value-bounds-op-interface-impl.mlir
@@ -0,0 +1,150 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module( \
+// RUN: func.func(test-affine-reify-value-bounds), \
+// RUN: gpu.module(llvm.func(test-affine-reify-value-bounds)), \
+// RUN: gpu.module(gpu.func(test-affine-reify-value-bounds)))' \
+// RUN: -verify-diagnostics \
+// RUN: -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @launch_func
+func.func @launch_func(%arg0 : index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ gpu.launch blocks(%block_id_x, %block_id_y, %block_id_z) in (%grid_dim_x = %arg0, %grid_dim_y = %c4, %grid_dim_z = %c2)
+ threads(%thread_id_x, %thread_id_y, %thread_id_z) in (%block_dim_x = %c64, %block_dim_y = %c4, %block_dim_z = %c2) {
+
+ // Sanity checks:
+ // expected-error @below{{unknown}}
+ "test.compare" (%thread_id_x, %c1) {cmp = "EQ"} : (index, index) -> ()
+ // expected-remark @below{{false}}
+ "test.compare" (%thread_id_x, %c64) {cmp = "GE"} : (index, index) -> ()
+
+ // expected-remark @below{{true}}
+ "test.compare" (%grid_dim_x, %c1) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare" (%grid_dim_x, %arg0) {cmp = "EQ"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare" (%grid_dim_y, %c4) {cmp = "EQ"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare" (%grid_dim_z, %c2) {cmp = "EQ"} : (index, index) -> ()
+
+ // expected-remark @below{{true}}
+ "test.compare"(%block_id_x, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%block_id_x, %arg0) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%block_id_y, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%block_id_y, %c4) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%block_id_z, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%block_id_z, %c2) {cmp = "LT"} : (index, index) -> ()
+
+ // expected-remark @below{{true}}
+ "test.compare" (%block_dim_x, %c64) {cmp = "EQ"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare" (%block_dim_y, %c4) {cmp = "EQ"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare" (%block_dim_z, %c2) {cmp = "EQ"} : (index, index) -> ()
+
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_x, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_x, %c64) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_y, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_y, %c4) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_z, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_z, %c2) {cmp = "LT"} : (index, index) -> ()
+ gpu.terminator
+ }
+
+ func.return
+}
+
+// -----
+
+// The ranges themselves are tested in int-range-interface.mlir, so here we
+// only check that the results of that interface propagate into value-bounds
+// constraints.
+
+// CHECK-LABEL: func @kernel
+module attributes {gpu.container_module} {
+ gpu.module @gpu_module {
+ llvm.func @kernel() attributes {gpu.kernel} {
+
+ %c0 = arith.constant 0 : index
+ %ctid_max = arith.constant 4294967295 : index
+ %thread_id_x = gpu.thread_id x
+
+ // expected-remark @below{{true}}
+ "test.compare" (%thread_id_x, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare" (%thread_id_x, %ctid_max) {cmp = "LT"} : (index, index) -> ()
+ llvm.return
+ }
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @annotated_kernel
+module attributes {gpu.container_module} {
+ gpu.module @gpu_module {
+ gpu.func @annotated_kernel() kernel
+ attributes {known_block_size = array<i32: 8, 12, 16>,
+ known_grid_size = array<i32: 20, 24, 28>} {
+
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %thread_id_x = gpu.thread_id x
+
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_x, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_x, %c8) {cmp = "LT"} : (index, index) -> ()
+
+ %block_dim_x = gpu.block_dim x
+ // expected-remark @below{{true}}
+ "test.compare"(%block_dim_x, %c8) {cmp = "EQ"} : (index, index) -> ()
+
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @local_bounds_kernel
+module attributes {gpu.container_module} {
+ gpu.module @gpu_module {
+ gpu.func @local_bounds_kernel() kernel {
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+
+ %block_dim_x = gpu.block_dim x upper_bound 8
+ // expected-remark @below{{true}}
+ "test.compare"(%block_dim_x, %c1) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%block_dim_x, %c8) {cmp = "LE"} : (index, index) -> ()
+ // expected-error @below{{unknown}}
+ "test.compare"(%block_dim_x, %c8) {cmp = "EQ"} : (index, index) -> ()
+
+ %thread_id_x = gpu.thread_id x upper_bound 8
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_x, %c0) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_x, %c8) {cmp = "LT"} : (index, index) -> ()
+
+ gpu.return
+ }
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
index 189c8e649ba5e2..bcd330443cc449 100644
--- a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -verify-diagnostics \
// RUN: -split-input-file | FileCheck %s
// CHECK-LABEL: func @linalg_fill(
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index 0e0f216b05d489..dc311c6b59ea47 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -verify-diagnostics \
// RUN: -split-input-file | FileCheck %s
// CHECK-LABEL: func @memref_alloc(
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index 65e1017e62c1a4..6e0c16a9a2b33f 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds{reify-to-func-args}))' \
// RUN: -verify-diagnostics -split-input-file | FileCheck %s
// CHECK-LABEL: func @scf_for(
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 0ba9983723a0a1..c0f64d3c843619 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -verify-diagnostics \
// RUN: -split-input-file | FileCheck %s
func.func @unknown_op() -> index {
diff --git a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
index 6af904beb660b5..d264d01f445ff3 100644
--- a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -cse -verify-diagnostics \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -cse -verify-diagnostics \
// RUN: -verify-diagnostics -split-input-file | FileCheck %s
#map_dim_i = affine_map<(d0)[s0] -> (-d0 + 32400, s0)>
diff --git a/mlir/test/Dialect/Vector/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Vector/value-bounds-op-interface-impl.mlir
index c04c82970f9c0a..1a94bbac9dff85 100644
--- a/mlir/test/Dialect/Vector/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Vector/value-bounds-op-interface-impl.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -verify-diagnostics \
// RUN: -split-input-file | FileCheck %s
// CHECK-LABEL: func @vector_transfer_write(
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 34513cd418e4c2..2c954ffc4acef9 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
@@ -30,7 +31,8 @@ namespace {
/// This pass applies the permutation on the first maximal perfect nest.
struct TestReifyValueBounds
- : public PassWrapper<TestReifyValueBounds, OperationPass<func::FuncOp>> {
+ : public PassWrapper<TestReifyValueBounds,
+ InterfacePass<FunctionOpInterface>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
StringRef getArgument() const final { return PASS_NAME; }
@@ -74,7 +76,7 @@ invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
/// Look for "test.reify_bound" ops in the input and replace their results with
/// the reified values.
-static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
+static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
bool reifyToFuncArgs,
bool useArithOps) {
IRRewriter rewriter(funcOp.getContext());
@@ -156,7 +158,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
}
/// Look for "test.compare" ops and emit errors/remarks.
-static LogicalResult testEquality(func::FuncOp funcOp) {
+static LogicalResult testEquality(FunctionOpInterface funcOp) {
IRRewriter rewriter(funcOp.getContext());
WalkResult result = funcOp.walk([&](test::CompareOp op) {
auto cmpType = op.getComparisonOperator();
>From 800aa8f67b86d113909caa0edab64d915b490ff1 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 9 Jan 2025 17:36:14 +0000
Subject: [PATCH 2/2] Review feedback, fix vector test
---
mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp | 2 +-
.../test/Dialect/GPU/value-bounds-op-interface-impl.mlir | 9 +++++++++
mlir/test/Dialect/Vector/test-scalable-bounds.mlir | 2 +-
3 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp
index cfca5cb04b1d2e..3bb7082daa5a01 100644
--- a/mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -35,7 +35,7 @@ struct GpuIdOpInterface
cstr.bound(v) >= range.smin().getSExtValue();
cstr.bound(v) <= range.smax().getSExtValue();
};
- // No arguments, so we don't need to pass in their ranges.
+ assert(inferrable->getNumOperands() == 0 && "ID ops have no operands");
inferrable.inferResultRanges({}, translateConstraint);
}
};
diff --git a/mlir/test/Dialect/GPU/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/GPU/value-bounds-op-interface-impl.mlir
index 719d8a94c982ba..6facf1e22aab90 100644
--- a/mlir/test/Dialect/GPU/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/GPU/value-bounds-op-interface-impl.mlir
@@ -62,6 +62,9 @@ func.func @launch_func(%arg0 : index) {
"test.compare"(%thread_id_z, %c0) {cmp = "GE"} : (index, index) -> ()
// expected-remark @below{{true}}
"test.compare"(%thread_id_z, %c2) {cmp = "LT"} : (index, index) -> ()
+
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_x, %block_dim_x) {cmp = "LT"} : (index, index) -> ()
gpu.terminator
}
@@ -114,6 +117,8 @@ module attributes {gpu.container_module} {
// expected-remark @below{{true}}
"test.compare"(%block_dim_x, %c8) {cmp = "EQ"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%thread_id_x, %block_dim_x) {cmp = "LT"} : (index, index) -> ()
gpu.return
}
}
@@ -144,6 +149,10 @@ module attributes {gpu.container_module} {
// expected-remark @below{{true}}
"test.compare"(%thread_id_x, %c8) {cmp = "LT"} : (index, index) -> ()
+ // Note: each upper_bound only bounds its own value, so the ID < size
+ // relationship between the two cannot be derived in this form.
+ // expected-error @below{{unknown}}
+ "test.compare"(%thread_id_x, %block_dim_x) {cmp = "LT"} : (index, index) -> ()
gpu.return
}
}
diff --git a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
index d264d01f445ff3..ddbd805b1cdec3 100644
--- a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -cse -verify-diagnostics \
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds, cse))' -verify-diagnostics \
// RUN: -verify-diagnostics -split-input-file | FileCheck %s
#map_dim_i = affine_map<(d0)[s0] -> (-d0 + 32400, s0)>
More information about the Mlir-commits
mailing list