[Mlir-commits] [mlir] 66e41a1 - [MLIR][NVVM] Declare InferIntRangeInterface for RangeableRegisterOp (#122263)
llvmlistbot at llvm.org
Fri Jan 10 01:32:29 PST 2025
Author: Guray Ozen
Date: 2025-01-10T10:32:25+01:00
New Revision: 66e41a1a20f2190a800669028a0e80bd86e735ce
URL: https://github.com/llvm/llvm-project/commit/66e41a1a20f2190a800669028a0e80bd86e735ce
DIFF: https://github.com/llvm/llvm-project/commit/66e41a1a20f2190a800669028a0e80bd86e735ce.diff
LOG: [MLIR][NVVM] Declare InferIntRangeInterface for RangeableRegisterOp (#122263)
Added:
mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 4fd00ff929bd70..50d1a39126ea3e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -19,6 +19,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a2d2102b59dece..0b9097e9bbca2c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -18,6 +18,7 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -134,8 +135,8 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "attr-dict `:` type($res)";
}
-class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
- NVVM_SpecialRegisterOp<mnemonic, traits> {
+// The optional `traits` parameter of the previous signature is kept so that
+// subclasses passing extra traits still work; InferIntRangeInterface is
+// appended so the optional `range` attribute can feed integer range analysis.
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_SpecialRegisterOp<mnemonic, !listconcat(traits,
+      [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -147,6 +148,17 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
}]>
];
+
+  // Implement InferIntRangeInterface::inferResultRanges by delegating to the
+  // static nvvmInferResultRanges helper in NVVMDialect.cpp, which reads the
+  // optional `range` attribute.
+  let extraClassDefinition = [{
+    // Infer the result ranges based on the range attribute.
+    void $cppClass::inferResultRanges(
+        ArrayRef<::mlir::ConstantIntRanges> argRanges,
+        SetIntRangeFn setResultRanges) {
+      nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
+    }
+  }];
+
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8b09c0f386d6b6..838159d676545d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1158,6 +1158,17 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
+/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
+/// have ConstantRangeAttr.
+///
+/// The `range` attribute follows LLVM's ConstantRange convention: it encodes
+/// the half-open interval [lower, upper), which may wrap around. In contrast,
+/// ConstantIntRanges stores *inclusive* min/max bounds. The upper bound must
+/// therefore be decremented when converting. Passing `upper` through unchanged
+/// would admit a value the register can never hold (e.g. tid.x == 32 for
+/// `range <i32, 0, 32>`). If the op carries no `range` attribute, no result
+/// range is reported.
+static void nvvmInferResultRanges(Operation *op, Value result,
+                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
+                                  SetIntRangeFn setResultRanges) {
+  auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range");
+  if (!rangeAttr)
+    return;
+
+  const APInt lower = rangeAttr.getLower();
+  const APInt upper = rangeAttr.getUpper();
+  if (lower.ult(upper)) {
+    // No unsigned wrap-around: [lower, upper - 1] is exact as unsigned, and
+    // fromUnsigned derives signed bounds consistent with it (full signed range
+    // if the interval crosses the sign bit).
+    setResultRanges(result,
+                    ConstantIntRanges::fromUnsigned(lower, upper - 1));
+  } else if (lower.slt(upper)) {
+    // Wraps as unsigned but not as signed, e.g. [-4, 4).
+    setResultRanges(result, ConstantIntRanges::fromSigned(lower, upper - 1));
+  } else {
+    // lower == upper (LLVM's full/empty encoding) or the interval wraps in
+    // both interpretations: no tighter bound is expressible, so stay
+    // conservative.
+    setResultRanges(result, ConstantIntRanges::maxRange(lower.getBitWidth()));
+  }
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
new file mode 100644
index 00000000000000..fae40dc7806ba6
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
+// The optional `range` attribute on NVVM special-register reads feeds
+// -int-range-optimizations. tid.x ([0, 32)) and tid.z ([0, 4)) can never be
+// greater than 64, so their comparisons fold to `false`. tid.y ([0, 128)) may
+// exceed 64, so its comparison is kept as-is.
+gpu.module @module{
+ gpu.func @kernel_1() kernel {
+ %tidx = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
+ %tidy = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
+ %tidz = nvvm.read.ptx.sreg.tid.z range <i32, 0, 4> : i32
+ %c64 = arith.constant 64 : i32
+
+ %1 = arith.cmpi sgt, %tidx, %c64 : i32
+ scf.if %1 {
+ gpu.printf "threadidx"
+ }
+ %2 = arith.cmpi sgt, %tidy, %c64 : i32
+ scf.if %2 {
+ gpu.printf "threadidy"
+ }
+ %3 = arith.cmpi sgt, %tidz, %c64 : i32
+ scf.if %3 {
+ gpu.printf "threadidz"
+ }
+ gpu.return
+ }
+}
+
+// Expected output: the tid.x and tid.z reads and compares disappear, and
+// their scf.if conditions become the shared `false` constant.
+// CHECK-LABEL: gpu.func @kernel_1
+// CHECK: %[[false:.+]] = arith.constant false
+// CHECK: %[[c64_i32:.+]] = arith.constant 64 : i32
+// CHECK: %[[S0:.+]] = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
+// CHECK: scf.if %[[false]] {
+// CHECK: gpu.printf "threadidx"
+// CHECK: %[[S1:.+]] = arith.cmpi sgt, %[[S0]], %[[c64_i32]] : i32
+// CHECK: scf.if %[[S1]] {
+// CHECK: gpu.printf "threadidy"
+// CHECK: scf.if %[[false]] {
+// CHECK: gpu.printf "threadidz"
More information about the Mlir-commits
mailing list