[flang-commits] [flang] [flang][CUF] Limit LICM for cuf.kernel. (PR #178073)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Mon Jan 26 14:58:26 PST 2026
https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/178073
This patch prevents hoisting of operations with reference operands.
Such hoisting may break the assumptions that later CUF passes
rely on.
From d3b8111eb7b85a38fc1f2f3d1a8528a2087354ad Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 26 Jan 2026 14:54:44 -0800
Subject: [PATCH] [flang][CUF] Limit LICM for cuf.kernel.
This patch prevents hoisting of operations with reference operands.
Such hoisting may break the assumptions that later CUF passes
rely on.
---
.../flang/Optimizer/Dialect/CUF/CUFOps.h | 1 +
.../flang/Optimizer/Dialect/CUF/CUFOps.td | 9 +-
flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp | 20 +++++
flang/test/Transforms/CUF/cuf-kernel-licm.fir | 82 +++++++++++++++++++
4 files changed, 109 insertions(+), 3 deletions(-)
create mode 100644 flang/test/Transforms/CUF/cuf-kernel-licm.fir
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
index 1edded090f8ce..d63d6142e5d66 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
@@ -11,6 +11,7 @@
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
+#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 34ac21c51b933..2bde3ac00a439 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -16,8 +16,9 @@
include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
-include "flang/Optimizer/Dialect/FIRTypes.td"
include "flang/Optimizer/Dialect/FIRAttr.td"
+include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.td"
+include "flang/Optimizer/Dialect/FIRTypes.td"
include "mlir/Dialect/GPU/IR/GPUBase.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/LoopLikeInterface.td"
@@ -243,8 +244,10 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
let hasVerifier = 1;
}
-def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+def cuf_KernelOp
+ : cuf_Op<"kernel", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+ DeclareOpInterfaceMethods<OperationMoveOpInterface>]> {
let description = [{
Represent the CUDA Fortran kernel directive. The operation is a loop like
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 97f7f76a8fbe7..9fa4f6bef404f 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -274,6 +274,26 @@ llvm::LogicalResult cuf::KernelOp::verify() {
return checkStreamType(*this);
}
+bool cuf::KernelOp::canMoveFromDescendant(mlir::Operation *descendant,
+ mlir::Operation *candidate) {
+ // Unconditional: moving ops out of loops inside cuf.kernel is always legal.
+ return true;
+}
+
+bool cuf::KernelOp::canMoveOutOf(mlir::Operation *candidate) {
+ // In general, some movement of operations out of cuf.kernel is allowed.
+ if (!candidate)
+ return true;
+
+ // Operations that have !fir.ref operands cannot be moved
+ // out of cuf.kernel, because this may break implicit data mapping
+ // passes that may run after LICM.
+ return !llvm::any_of(candidate->getOperands(),
+ [&](mlir::Value candidateOperand) {
+ return fir::isa_ref_type(candidateOperand.getType());
+ });
+}
+
//===----------------------------------------------------------------------===//
// RegisterKernelOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Transforms/CUF/cuf-kernel-licm.fir b/flang/test/Transforms/CUF/cuf-kernel-licm.fir
new file mode 100644
index 0000000000000..7d70b934f321b
--- /dev/null
+++ b/flang/test/Transforms/CUF/cuf-kernel-licm.fir
@@ -0,0 +1,82 @@
+// RUN: fir-opt -flang-licm --split-input-file %s | FileCheck %s
+
+// Verify that Pure fir.convert operations with !fir.ref operands
+// are not hoisted by LICM out of cuf.kernel.
+// CHECK-LABEL: func.func @_QPtest(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<10xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"},
+// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<!fir.array<10xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "ra"}) {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 10 : index
+// CHECK: %[[ALLOC_0:.*]] = cuf.alloc !fir.box<!fir.ptr<!fir.array<?xf32>>> {data_attr = #cuf.cuda<device>} -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[DUMMY_SCOPE_0:.*]] = fir.dummy_scope : !fir.dscope
+// CHECK: %[[SHAPE_0:.*]] = fir.shape %[[CONSTANT_2]] : (index) -> !fir.shape<1>
+// CHECK: %[[DECLARE_0:.*]] = fir.declare %[[ARG0]](%[[SHAPE_0]]) dummy_scope %[[DUMMY_SCOPE_0]] arg 1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[ALLOCA_0:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtestEi"}
+// CHECK: %[[DECLARE_1:.*]] = fir.declare %[[ALLOCA_0]] {uniq_name = "_QFtestEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+// CHECK: %[[ALLOCA_1:.*]] = fir.alloca i64 {bindc_name = "pa", uniq_name = "_QFtestEpa"}
+// CHECK: %[[DECLARE_2:.*]] = fir.declare %[[ALLOCA_1]] {fortran_attrs = #fir.var_attrs<cray_pointer>, uniq_name = "_QFtestEpa"} : (!fir.ref<i64>) -> !fir.ref<i64>
+// CHECK: %[[DECLARE_3:.*]] = fir.declare %[[ALLOC_0]] {fortran_attrs = #fir.var_attrs<pointer, cray_pointee>, uniq_name = "_QFtestEra"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[ZERO_BITS_0:.*]] = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
+// CHECK: %[[EMBOX_0:.*]] = fir.embox %[[ZERO_BITS_0]](%[[SHAPE_0]]) : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+// CHECK: fir.store %[[EMBOX_0]] to %[[DECLARE_3]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[DECLARE_0]] : (!fir.ref<!fir.array<10xf32>>) -> i64
+// CHECK: fir.store %[[CONVERT_0]] to %[[DECLARE_2]] : !fir.ref<i64>
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : index
+// CHECK: cuf.kernel<<<*, *>>> (%[[VAL_0:.*]] : index) = (%[[CONSTANT_1]] : index) to (%[[CONSTANT_2]] : index) step (%[[CONSTANT_1]] : index) {
+// CHECK: %[[CONVERT_1:.*]] = fir.convert %[[VAL_0]] : (index) -> i32
+// CHECK: fir.store %[[CONVERT_1]] to %[[DECLARE_1]] : !fir.ref<i32>
+// CHECK: %[[CONVERT_2:.*]] = fir.convert %[[DECLARE_2]] : (!fir.ref<i64>) -> !fir.ref<!fir.ptr<i64>>
+// CHECK: %[[LOAD_0:.*]] = fir.load %[[CONVERT_2]] : !fir.ref<!fir.ptr<i64>>
+// CHECK: %[[CONVERT_3:.*]] = fir.convert %[[DECLARE_3]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %[[CONVERT_4:.*]] = fir.convert %[[LOAD_0]] : (!fir.ptr<i64>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranAPointerAssociateScalar(%[[CONVERT_3]], %[[CONVERT_4]]) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>) -> ()
+// CHECK: %[[LOAD_1:.*]] = fir.load %[[DECLARE_3]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[LOAD_2:.*]] = fir.load %[[DECLARE_1]] : !fir.ref<i32>
+// CHECK: %[[CONVERT_5:.*]] = fir.convert %[[LOAD_2]] : (i32) -> i64
+// CHECK: %[[BOX_DIMS_0:.*]]:3 = fir.box_dims %[[LOAD_1]], %[[CONSTANT_3]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> (index, index, index)
+// CHECK: %[[SHIFT_0:.*]] = fir.shift %[[BOX_DIMS_0]]#0 : (index) -> !fir.shift<1>
+// CHECK: %[[ARRAY_COOR_0:.*]] = fir.array_coor %[[LOAD_1]](%[[SHIFT_0]]) %[[CONVERT_5]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, !fir.shift<1>, i64) -> !fir.ref<f32>
+// CHECK: fir.store %[[CONSTANT_0]] to %[[ARRAY_COOR_0]] : !fir.ref<f32>
+// CHECK: "fir.end"() : () -> ()
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @_QPtest(%arg0: !fir.ref<!fir.array<10xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %arg1: !fir.ref<!fir.array<10xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "ra"}) {
+ %cst = arith.constant 1.000000e+00 : f32
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %0 = cuf.alloc !fir.box<!fir.ptr<!fir.array<?xf32>>> {data_attr = #cuf.cuda<device>} -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %1 = fir.dummy_scope : !fir.dscope
+ %2 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %3 = fir.declare %arg0(%2) dummy_scope %1 arg 1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> !fir.ref<!fir.array<10xf32>>
+ %4 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtestEi"}
+ %5 = fir.declare %4 {uniq_name = "_QFtestEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
+ %6 = fir.alloca i64 {bindc_name = "pa", uniq_name = "_QFtestEpa"}
+ %7 = fir.declare %6 {fortran_attrs = #fir.var_attrs<cray_pointer>, uniq_name = "_QFtestEpa"} : (!fir.ref<i64>) -> !fir.ref<i64>
+ %8 = fir.declare %0 {fortran_attrs = #fir.var_attrs<pointer, cray_pointee>, uniq_name = "_QFtestEra"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %9 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
+ %10 = fir.embox %9(%2) : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ fir.store %10 to %8 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %11 = fir.convert %3 : (!fir.ref<!fir.array<10xf32>>) -> i64
+ fir.store %11 to %7 : !fir.ref<i64>
+ cuf.kernel<<<*, *>>> (%arg2 : index) = (%c1 : index) to (%c10 : index) step (%c1 : index) {
+ %12 = fir.convert %arg2 : (index) -> i32
+ fir.store %12 to %5 : !fir.ref<i32>
+ %13 = fir.convert %7 : (!fir.ref<i64>) -> !fir.ref<!fir.ptr<i64>>
+ %14 = fir.load %13 : !fir.ref<!fir.ptr<i64>>
+ %15 = fir.convert %8 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+ %16 = fir.convert %14 : (!fir.ptr<i64>) -> !fir.llvm_ptr<i8>
+ fir.call @_FortranAPointerAssociateScalar(%15, %16) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>) -> ()
+ %17 = fir.load %8 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %18 = fir.load %5 : !fir.ref<i32>
+ %19 = fir.convert %18 : (i32) -> i64
+ %c0 = arith.constant 0 : index
+ %20:3 = fir.box_dims %17, %c0 : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> (index, index, index)
+ %21 = fir.shift %20#0 : (index) -> !fir.shift<1>
+ %22 = fir.array_coor %17(%21) %19 : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, !fir.shift<1>, i64) -> !fir.ref<f32>
+ fir.store %cst to %22 : !fir.ref<f32>
+ "fir.end"() : () -> ()
+ }
+ return
+}
More information about the flang-commits
mailing list