[Mlir-commits] [mlir] [mlir] IntegerRangeAnalysis: don't loop over splat attr (PR #115229)

Ian Wood llvmlistbot at llvm.org
Wed Nov 6 14:53:13 PST 2024


https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/115229

>From 9f2efedfd9ecb5f1e9675678f97424caec4ed0f1 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 7 Nov 2024 02:33:40 -0800
Subject: [PATCH 1/2] IntegerRangeAnalysis: don't loop over splat attr

---
 mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 8682294c8a6972..6646c189eb1c3e 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -42,6 +43,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   }
   if (auto arrayCstAttr =
           llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
+    if (arrayCstAttr.isSplat()) {
+      setResultRange(getResult(), ConstantIntRanges::constant(
+                                      arrayCstAttr.getSplatValue<APInt>()));
+      return;
+    }
+
     std::optional<ConstantIntRanges> result;
     for (const APInt &val : arrayCstAttr) {
       auto range = ConstantIntRanges::constant(val);

>From aa13a0847c6c40eb1b3fea3e5981fd0899761c76 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 7 Nov 2024 02:49:07 -0800
Subject: [PATCH 2/2] Move assert before early return

---
 mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 6646c189eb1c3e..6a14a03e89d6b9 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -43,6 +42,7 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   }
   if (auto arrayCstAttr =
           llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
+    assert(arrayCstAttr.size() && "Zero-sized vectors are not allowed");
     if (arrayCstAttr.isSplat()) {
       setResultRange(getResult(), ConstantIntRanges::constant(
                                       arrayCstAttr.getSplatValue<APInt>()));
@@ -55,7 +55,6 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
       result = (result ? result->rangeUnion(range) : range);
     }
 
-    assert(result && "Zero-sized vectors are not allowed");
     setResultRange(getResult(), *result);
     return;
   }



More information about the Mlir-commits mailing list