[Mlir-commits] [mlir] [mlir] IntegerRangeAnalysis: add support for vector type (PR #112292)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 14 18:38:33 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-vector
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
Treat the integer range of a vector-typed value as the union of the ranges of its individual elements. With these semantics, most arith ops on vectors work out of the box. The only special handling needed is for constants and for ops that manipulate vector elements (extract, extractelement, splat).
The end goal of these changes is to enable optimization of vectorized index calculations.
---
Full diff: https://github.com/llvm/llvm-project/pull/112292.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+7-3)
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+4-2)
- (modified) mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp (+19-2)
- (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+12-6)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15)
- (modified) mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir (+1-1)
- (added) mlir/test/Dialect/Vector/int-range-interface.mlir (+57)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+3-2)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b0de7c11b9d436..d890d5017daca7 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -13,20 +13,21 @@
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
-include "mlir/Dialect/Vector/IR/Vector.td"
-include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
-include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/Vector/IR/Vector.td"
+include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
+include "mlir/IR/EnumAttr.td"
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
@@ -627,6 +628,7 @@ def Vector_DeinterleaveOp :
def Vector_ExtractElementOp :
Vector_Op<"extractelement", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"result type matches element type of vector operand",
"vector", "result",
"::llvm::cast<VectorType>($_self).getElementType()">]>,
@@ -673,6 +675,7 @@ def Vector_ExtractElementOp :
def Vector_ExtractOp :
Vector_Op<"extract", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
InferTypeOpAdaptorWithIsCompatible]> {
@@ -2795,6 +2798,7 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
def Vector_SplatOp : Vector_Op<"splat", [
Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
"::llvm::cast<VectorType>($_self).getElementType()">
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index bf9eabbedc3a1f..a97e43708d9a37 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -53,9 +54,10 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
dialect = parent->getDialect();
else
dialect = value.getParentBlock()->getParentOp()->getDialect();
+
+ Type type = getElementTypeOrSelf(value);
solver->propagateIfChanged(
- cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
- dialect)));
+ cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
}
LogicalResult IntegerRangeAnalysis::visitOperation(
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 462044417b5fb8..3df483a4d2ddd0 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -35,10 +35,27 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
- if (constAttr) {
+ if (auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
const APInt &value = constAttr.getValue();
setResultRange(getResult(), ConstantIntRanges::constant(value));
+ return;
+ }
+ if (auto constAttr =
+ llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
+ std::optional<ConstantIntRanges> result;
+ for (APInt &&val : constAttr) {
+ auto range = ConstantIntRanges::constant(val);
+ if (!result) {
+ result = range;
+ } else {
+ result = result->rangeUnion(range);
+ }
+ }
+
+ if (result)
+ setResultRange(getResult(), *result);
+
+ return;
}
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 521138c1f6f4cd..d494bba081f801 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -51,21 +51,27 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
if (!maybeConstValue.has_value())
return failure();
+ Type type = value.getType();
+ Location loc = value.getLoc();
Operation *maybeDefiningOp = value.getDefiningOp();
Dialect *valueDialect =
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
- Attribute constAttr =
- rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
- Operation *constOp = valueDialect->materializeConstant(
- rewriter, constAttr, value.getType(), value.getLoc());
+
+ Attribute constAttr;
+ if (auto shaped = dyn_cast<ShapedType>(type)) {
+ constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
+ } else {
+ constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
+ }
+ Operation *constOp =
+ valueDialect->materializeConstant(rewriter, constAttr, type, loc);
// Fall back to arith.constant if the dialect materializer doesn't know what
// to do with an integer constant.
if (!constOp)
constOp = rewriter.getContext()
->getLoadedDialect<ArithDialect>()
- ->materializeConstant(rewriter, constAttr, value.getType(),
- value.getLoc());
+ ->materializeConstant(rewriter, constAttr, type, loc);
if (!constOp)
return failure();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..43920cb5cf30d3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1221,6 +1221,11 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ExtractElementOp
//===----------------------------------------------------------------------===//
+void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ setResultRanges(getResult(), argRanges.front());
+}
+
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
Value source) {
result.addOperands({source});
@@ -1273,6 +1278,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// ExtractOp
//===----------------------------------------------------------------------===//
+void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ setResultRanges(getResult(), argRanges.front());
+}
+
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, int64_t position) {
build(builder, result, source, ArrayRef<int64_t>{position});
@@ -6423,6 +6433,11 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}
+void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ setResultRanges(getResult(), argRanges.front());
+}
+
//===----------------------------------------------------------------------===//
// StepOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
index 49bd74cfe9124a..9f3d575838f320 100644
--- a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
+++ b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
@@ -100,7 +100,7 @@ func.func @dead_code() {
// Make sure not crash.
// CHECK-LABEL: @no_integer_or_index
func.func @no_integer_or_index() {
- // CHECK: arith.cmpi
+ // CHECK: arith.constant dense<false> : vector<1xi1>
%cst_0 = arith.constant dense<[0]> : vector<1xi32>
%cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32>
return
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
new file mode 100644
index 00000000000000..0fac3c417b7bad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
+
+
+// CHECK-LABEL: func @constant_vec
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+func.func @constant_vec() -> vector<8xindex> {
+ %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+ %1 = test.reflect_bounds %0 : vector<8xindex>
+ func.return %1 : vector<8xindex>
+}
+
+// CHECK-LABEL: func @constant_splat
+// CHECK: test.reflect_bounds {smax = 3 : si32, smin = 3 : si32, umax = 3 : ui32, umin = 3 : ui32}
+func.func @constant_splat() -> vector<8xi32> {
+ %0 = arith.constant dense<3> : vector<8xi32>
+ %1 = test.reflect_bounds %0 : vector<8xi32>
+ func.return %1 : vector<8xi32>
+}
+
+
+// CHECK-LABEL: func @vector_splat
+// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
+func.func @vector_splat() -> vector<4xindex> {
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+ %1 = vector.splat %0 : vector<4xindex>
+ %2 = test.reflect_bounds %1 : vector<4xindex>
+ func.return %2 : vector<4xindex>
+}
+
+// CHECK-LABEL: func @vector_extract
+// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
+func.func @vector_extract() -> index {
+ %0 = test.with_bounds { umin = 5 : index, umax = 6 : index, smin = 5 : index, smax = 6 : index } : vector<4xindex>
+ %1 = vector.extract %0[0] : index from vector<4xindex>
+ %2 = test.reflect_bounds %1 : index
+ func.return %2 : index
+}
+
+// CHECK-LABEL: func @vector_extractelement
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index}
+func.func @vector_extractelement() -> index {
+ %c0 = arith.constant 0 : index
+ %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
+ %1 = vector.extractelement %0[%c0 : index] : vector<4xindex>
+ %2 = test.reflect_bounds %1 : index
+ func.return %2 : index
+}
+
+// CHECK-LABEL: func @vector_add
+// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index}
+func.func @vector_add() -> vector<4xindex> {
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
+ %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
+ %2 = arith.addi %0, %1 : vector<4xindex>
+ %3 = test.reflect_bounds %2 : vector<4xindex>
+ func.return %3 : vector<4xindex>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 69091fb893fad6..b268e549b93ab6 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -760,12 +760,13 @@ void TestReflectBoundsOp::inferResultRanges(
Type sIntTy, uIntTy;
// For plain `IntegerType`s, we can derive the appropriate signed and unsigned
// Types for the Attributes.
- if (auto intTy = llvm::dyn_cast<IntegerType>(getType())) {
+ Type type = getElementTypeOrSelf(getType());
+ if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
unsigned bitwidth = intTy.getWidth();
sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
} else
- sIntTy = uIntTy = getType();
+ sIntTy = uIntTy = type;
setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9e19966414d1d7..301f55c670d752 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2781,7 +2781,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
//===----------------------------------------------------------------------===//
// Test InferIntRangeInterface
//===----------------------------------------------------------------------===//
-def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>;
+def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Index]>]>;
def TestWithBoundsOp : TEST_Op<"with_bounds",
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/112292
More information about the Mlir-commits mailing list