[Mlir-commits] [mlir] [MLIR][NVVM] Declare InferIntRangeInterface for RangeableRegisterOp (PR #122263)

Guray Ozen llvmlistbot at llvm.org
Thu Jan 9 04:08:02 PST 2025


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/122263

NVVM operations derived from `NVVM_SpecialRangeableRegisterOp` may carry an optional `range` attribute. When this attribute is present, the integer range of the operation's result can be determined from it.

This PR declares the `InferIntRangeInterface` for all `SpecialRangeableRegister` operations and extracts range information from the `range` attribute.

>From 0c7d2661bb24eeaca34997896c1d0aed34c30b5f Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Thu, 9 Jan 2025 13:06:47 +0100
Subject: [PATCH] [MLIR][NVVM] Declare `InferIntRangeInterface` for
 `SpecialRangeableRegisterOp`

`SpecialRangeableRegister` NVVM operations may have a `range` attribute set. When this attribute is present, it becomes possible to determine their range.

This PR declares the `InferIntRangeInterface` for all `SpecialRangeableRegister` operations and extracts range information from the `range` attribute.
---
 .../include/mlir/Dialect/LLVMIR/NVVMDialect.h |  1 +
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 20 +++++++++++--
 mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir | 28 +++++++++++++++++++
 3 files changed, 47 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 4fd00ff929bd70..50d1a39126ea3e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a2d2102b59dece..d0d720e664ce3b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -18,6 +18,7 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 
 def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -134,8 +135,8 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
-  NVVM_SpecialRegisterOp<mnemonic, traits> {
+class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
+  NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
   let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
   let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -147,6 +148,21 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
       build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
     }]>
   ];
+
+  // Out-of-line body of the inferResultRanges hook declared via DeclareOpInterfaceMethods<InferIntRangeInterface> above.
+  let extraClassDefinition = [{
+    // If a "range" attribute is present, pass its lower/upper bounds (given twice) to setResultRanges; otherwise report nothing. argRanges is unused.
+    void $cppClass::inferResultRanges(
+        ArrayRef<::mlir::ConstantIntRanges> argRanges,
+        SetIntRangeFn setResultRanges) {
+        if (auto rangeAttr = getOperation()->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
+          setResultRanges(getResult(), 
+                          {rangeAttr.getLower(), rangeAttr.getUpper(),
+                          rangeAttr.getLower(), rangeAttr.getUpper()});  // NOTE(review): bounds are forwarded unadjusted; confirm the attribute's upper-bound inclusivity and wrap-around behavior match what the range type built here expects
+        }
+    }
+  }];
+
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
new file mode 100644
index 00000000000000..7014ff8aaa80b7
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
+gpu.module @module{
+    gpu.func @kernel_1() kernel {
+        %tidx = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32 // declared range stays below the 64 threshold
+        %tidy = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32 // declared range extends past the 64 threshold
+        %tidz = nvvm.read.ptx.sreg.tid.z range <i32, 0, 4> : i32 // declared range stays below the 64 threshold
+        %c64 = arith.constant 64 : i32 // threshold each thread id is compared against
+        
+        %1 = arith.cmpi sgt, %tidx, %c64 : i32 // expected to fold to false, so the guarded printf is removed
+        scf.if %1 {
+            gpu.printf "threadidx"
+        }
+        %2 = arith.cmpi sgt, %tidy, %c64 : i32 // cannot be decided from the range, so the guarded printf must remain
+        scf.if %2 {
+            gpu.printf "threadidy"
+        }
+        %3 = arith.cmpi sgt, %tidz, %c64 : i32 // expected to fold to false, so the guarded printf is removed
+        scf.if %3 {
+            gpu.printf "threadidz"
+        }
+        gpu.return
+    }
+}
+
+// CHECK-LABEL: gpu.func @kernel_1
+// CHECK-NOT: gpu.printf "threadidx"
+// CHECK: gpu.printf "threadidy"
+// CHECK-NOT: gpu.printf "threadidz"
\ No newline at end of file



More information about the Mlir-commits mailing list