[Mlir-commits] [mlir] [mlir] IntegerRangeAnalysis: add support for vector type (PR #112292)
Ivan Butygin
llvmlistbot at llvm.org
Fri Nov 1 12:05:14 PDT 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/112292
>From c29fe1753728d1e8958ae16a267b62aeb4491a82 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 02:40:36 +0200
Subject: [PATCH 1/6] [mlir] IntegerRangeAnalysis: add support for vector type
Treat the integer range of a vector-typed value as the union of the ranges of its individual elements.
With these semantics, most arith ops on vectors work out of the box; the only special handling needed is for constants and for ops that manipulate vector elements.
The end goal of these changes is to optimize vectorized index calculations.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 10 +++-
.../DataFlow/IntegerRangeAnalysis.cpp | 6 +-
.../Arith/IR/InferIntRangeInterfaceImpls.cpp | 21 ++++++-
.../Transforms/IntRangeOptimizations.cpp | 18 ++++--
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 15 +++++
.../Dialect/Vector/int-range-interface.mlir | 57 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 5 +-
mlir/test/lib/Dialect/Test/TestOps.td | 2 +-
8 files changed, 118 insertions(+), 16 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/int-range-interface.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b54a8b7fe8680dc..f02f7bb599378a6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -13,20 +13,21 @@
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
-include "mlir/Dialect/Vector/IR/Vector.td"
-include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
-include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/Vector/IR/Vector.td"
+include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
+include "mlir/IR/EnumAttr.td"
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
@@ -627,6 +628,7 @@ def Vector_DeinterleaveOp :
def Vector_ExtractElementOp :
Vector_Op<"extractelement", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"result type matches element type of vector operand",
"vector", "result",
"::llvm::cast<VectorType>($_self).getElementType()">]>,
@@ -673,6 +675,7 @@ def Vector_ExtractElementOp :
def Vector_ExtractOp :
Vector_Op<"extract", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
InferTypeOpAdaptorWithIsCompatible]> {
@@ -2801,6 +2804,7 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
def Vector_SplatOp : Vector_Op<"splat", [
Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
"::llvm::cast<VectorType>($_self).getElementType()">
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index bf9eabbedc3a1fe..a97e43708d9a37d 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -53,9 +54,10 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
dialect = parent->getDialect();
else
dialect = value.getParentBlock()->getParentOp()->getDialect();
+
+ Type type = getElementTypeOrSelf(value);
solver->propagateIfChanged(
- cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
- dialect)));
+ cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
}
LogicalResult IntegerRangeAnalysis::visitOperation(
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 462044417b5fb87..3df483a4d2ddd02 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -35,10 +35,27 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
+/// Infers the result range of `arith.constant`. A scalar `IntegerAttr` value
+/// yields exactly that constant; a `DenseIntElementsAttr` value yields the
+/// union of the constant ranges of its elements (no range is set if it has no
+/// elements). Any other value kind leaves the result range unset.
void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
- if (constAttr) {
+ if (auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
const APInt &value = constAttr.getValue();
setResultRange(getResult(), ConstantIntRanges::constant(value));
+ return;
+ }
+ if (auto constAttr =
+ llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
+ std::optional<ConstantIntRanges> result;
+ for (APInt &&val : constAttr) {
+ auto range = ConstantIntRanges::constant(val);
+ if (!result) {
+ result = range;
+ } else {
+ result = result->rangeUnion(range);
+ }
+ }
+
+ if (result)
+ setResultRange(getResult(), *result);
+
+ return;
}
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 521138c1f6f4cd8..d494bba081f801e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -51,21 +51,27 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
if (!maybeConstValue.has_value())
return failure();
+ Type type = value.getType();
+ Location loc = value.getLoc();
Operation *maybeDefiningOp = value.getDefiningOp();
Dialect *valueDialect =
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
- Attribute constAttr =
- rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
- Operation *constOp = valueDialect->materializeConstant(
- rewriter, constAttr, value.getType(), value.getLoc());
+
+ Attribute constAttr;
+ if (auto shaped = dyn_cast<ShapedType>(type)) {
+ constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
+ } else {
+ constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
+ }
+ Operation *constOp =
+ valueDialect->materializeConstant(rewriter, constAttr, type, loc);
// Fall back to arith.constant if the dialect materializer doesn't know what
// to do with an integer constant.
if (!constOp)
constOp = rewriter.getContext()
->getLoadedDialect<ArithDialect>()
- ->materializeConstant(rewriter, constAttr, value.getType(),
- value.getLoc());
+ ->materializeConstant(rewriter, constAttr, type, loc);
if (!constOp)
return failure();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5d018bdbe0b2479..68f41c1a180c57f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1221,6 +1221,11 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ExtractElementOp
//===----------------------------------------------------------------------===//
+/// The extracted scalar is one of the source vector's elements. A vector's
+/// range covers all of its elements, so the result reuses the range of the
+/// first operand (the source vector); the position operand is not consulted.
+void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &vectorRange = argRanges[0];
+ setResultRanges(getResult(), vectorRange);
+}
+
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
Value source) {
result.addOperands({source});
@@ -1273,6 +1278,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// ExtractOp
//===----------------------------------------------------------------------===//
+/// The extracted value is made of elements of the source vector (the first
+/// operand), so it takes that operand's range unchanged.
+void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &vectorRange = argRanges[0];
+ setResultRanges(getResult(), vectorRange);
+}
+
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, int64_t position) {
build(builder, result, source, ArrayRef<int64_t>{position});
@@ -6423,6 +6433,11 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}
+/// Every element of the splat result is a copy of the scalar operand, so the
+/// result vector has exactly the operand's range.
+void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &inputRange = argRanges[0];
+ setResultRanges(getResult(), inputRange);
+}
+
//===----------------------------------------------------------------------===//
// WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
new file mode 100644
index 000000000000000..0fac3c417b7badc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
+// Integer range inference on vector values: a vector's range is the union of
+// the ranges of its elements.
+
+
+// CHECK-LABEL: func @constant_vec
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+func.func @constant_vec() -> vector<8xindex> {
+ // A dense vector constant's range is the union of its elements: [0, 7].
+ %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+ %1 = test.reflect_bounds %0 : vector<8xindex>
+ func.return %1 : vector<8xindex>
+}
+
+// CHECK-LABEL: func @constant_splat
+// CHECK: test.reflect_bounds {smax = 3 : si32, smin = 3 : si32, umax = 3 : ui32, umin = 3 : ui32}
+func.func @constant_splat() -> vector<8xi32> {
+ // Every element of the splat constant is 3, so the range is exactly [3, 3].
+ %0 = arith.constant dense<3> : vector<8xi32>
+ %1 = test.reflect_bounds %0 : vector<8xi32>
+ func.return %1 : vector<8xi32>
+}
+
+
+// CHECK-LABEL: func @vector_splat
+// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
+func.func @vector_splat() -> vector<4xindex> {
+ // Every element is a copy of a scalar in [4, 5], so the vector range is [4, 5].
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+ %1 = vector.splat %0 : vector<4xindex>
+ %2 = test.reflect_bounds %1 : vector<4xindex>
+ func.return %2 : vector<4xindex>
+}
+
+// CHECK-LABEL: func @vector_extract
+// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
+func.func @vector_extract() -> index {
+ // The extracted element lies in the source vector's range, [5, 6].
+ %0 = test.with_bounds { umin = 5 : index, umax = 6 : index, smin = 5 : index, smax = 6 : index } : vector<4xindex>
+ %1 = vector.extract %0[0] : index from vector<4xindex>
+ %2 = test.reflect_bounds %1 : index
+ func.return %2 : index
+}
+
+// CHECK-LABEL: func @vector_extractelement
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index}
+func.func @vector_extractelement() -> index {
+ // The extracted element lies in the source vector's range, [6, 7]; the
+ // range of the position operand (%c0) does not contribute.
+ %c0 = arith.constant 0 : index
+ %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
+ %1 = vector.extractelement %0[%c0 : index] : vector<4xindex>
+ %2 = test.reflect_bounds %1 : index
+ func.return %2 : index
+}
+
+// CHECK-LABEL: func @vector_add
+// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index}
+func.func @vector_add() -> vector<4xindex> {
+ // Element-wise add of [4, 5] and [6, 7] gives [10, 12].
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
+ %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
+ %2 = arith.addi %0, %1 : vector<4xindex>
+ %3 = test.reflect_bounds %2 : vector<4xindex>
+ func.return %3 : vector<4xindex>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 69091fb893fad60..b268e549b93ab6a 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -760,12 +760,13 @@ void TestReflectBoundsOp::inferResultRanges(
Type sIntTy, uIntTy;
// For plain `IntegerType`s, we can derive the appropriate signed and unsigned
// Types for the Attributes.
- if (auto intTy = llvm::dyn_cast<IntegerType>(getType())) {
+ Type type = getElementTypeOrSelf(getType());
+ if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
unsigned bitwidth = intTy.getWidth();
sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
} else
- sIntTy = uIntTy = getType();
+ sIntTy = uIntTy = type;
setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bc6c6cf213ea4b6..e51778c4433889b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2781,7 +2781,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
//===----------------------------------------------------------------------===//
// Test InferIntRangeInterface
//===----------------------------------------------------------------------===//
-def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>;
+// Value types handled by the InferIntRangeInterface test ops (e.g. the
+// `$value` operand of `test.reflect_bounds`): integers, index, and vectors of
+// either. A vector's range is the union of the ranges of its elements.
+def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Index]>]>;
def TestWithBoundsOp : TEST_Op<"with_bounds",
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
>From b9a4ccde430e0efbdd31dce41356f424df558442 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 26 Oct 2024 13:08:40 +0200
Subject: [PATCH 2/6] add some more ops
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 3 ++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 15 ++++++++++
.../Dialect/Vector/int-range-interface.mlir | 30 +++++++++++++++++++
3 files changed, 48 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f02f7bb599378a6..cda14c0c69da145 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyType:$source)>,
@@ -813,6 +814,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
def Vector_InsertElementOp :
Vector_Op<"insertelement", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"source operand type matches element type of result",
"result", "source",
"::llvm::cast<VectorType>($_self).getElementType()">,
@@ -861,6 +863,7 @@ def Vector_InsertElementOp :
def Vector_InsertOp :
Vector_Op<"insert", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "result"]>]> {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 68f41c1a180c57f..4f94d6a6fabdb69 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2262,6 +2262,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
// BroadcastOp
//===----------------------------------------------------------------------===//
+/// Broadcasting only replicates the source operand's value(s), so the result
+/// has exactly the source operand's range.
+void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &sourceRange = argRanges[0];
+ setResultRanges(getResult(), sourceRange);
+}
+
/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
@@ -2723,6 +2728,11 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
// InsertElementOp
//===----------------------------------------------------------------------===//
+/// Each result element is either the inserted scalar (first operand) or an
+/// element of the destination vector (second operand), so the result range is
+/// the union of those two ranges.
+void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &sourceRange = argRanges[0];
+ const ConstantIntRanges &destRange = argRanges[1];
+ setResultRanges(getResult(), sourceRange.rangeUnion(destRange));
+}
+
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest) {
build(builder, result, source, dest, {});
@@ -2772,6 +2782,11 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
+/// Each result element comes either from the inserted value (first operand)
+/// or from the destination vector (second operand), so the result range is
+/// the union of those two ranges.
+void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &sourceRange = argRanges[0];
+ const ConstantIntRanges &destRange = argRanges[1];
+ setResultRanges(getResult(), sourceRange.rangeUnion(destRange));
+}
+
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, int64_t position) {
build(builder, result, source, dest, ArrayRef<int64_t>{position});
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 0fac3c417b7badc..6c9b90e99017508 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -27,6 +27,15 @@ func.func @vector_splat() -> vector<4xindex> {
func.return %2 : vector<4xindex>
}
+// CHECK-LABEL: func @vector_broadcast
+// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
+func.func @vector_broadcast() -> vector<4x16xindex> {
+ // Broadcasting only replicates the source elements, so the range stays [4, 5].
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex>
+ %1 = vector.broadcast %0 : vector<16xindex> to vector<4x16xindex>
+ %2 = test.reflect_bounds %1 : vector<4x16xindex>
+ func.return %2 : vector<4x16xindex>
+}
+
// CHECK-LABEL: func @vector_extract
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
func.func @vector_extract() -> index {
@@ -55,3 +64,24 @@ func.func @vector_add() -> vector<4xindex> {
%3 = test.reflect_bounds %2 : vector<4xindex>
func.return %3 : vector<4xindex>
}
+
+// CHECK-LABEL: func @vector_insert
+// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
+func.func @vector_insert() -> vector<4xindex> {
+ // Result elements come from the dest vector ([5, 7]) or the inserted value
+ // ([6, 8]), so the range is their union, [5, 8].
+ %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
+ %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
+ %2 = vector.insert %1, %0[0] : index into vector<4xindex>
+ %3 = test.reflect_bounds %2 : vector<4xindex>
+ func.return %3 : vector<4xindex>
+}
+
+// CHECK-LABEL: func @vector_insertelement
+// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
+func.func @vector_insertelement() -> vector<4xindex> {
+ // Result elements come from the dest vector ([5, 7]) or the inserted scalar
+ // ([6, 8]), so the range is their union, [5, 8].
+ %c0 = arith.constant 0 : index
+ %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
+ %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
+ %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex>
+ %3 = test.reflect_bounds %2 : vector<4xindex>
+ func.return %3 : vector<4xindex>
+}
>From 506cf834847cb4a31330fc2a9486f60f6e09d9ac Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 26 Oct 2024 13:20:33 +0200
Subject: [PATCH 3/6] shape_cast
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 4 +++-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +++++
mlir/test/Dialect/Vector/int-range-interface.mlir | 9 +++++++++
3 files changed, 17 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cda14c0c69da145..3f45d0804e04505 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2210,7 +2210,9 @@ def Vector_CompressStoreOp :
}
def Vector_ShapeCastOp :
- Vector_Op<"shape_cast", [Pure]>,
+ Vector_Op<"shape_cast", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
+ ]>,
Arguments<(ins AnyVectorOfAnyRank:$source)>,
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4f94d6a6fabdb69..d8913251e56e9ee 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5302,6 +5302,11 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ShapeCastOp
//===----------------------------------------------------------------------===//
+/// A shape cast only rearranges the source vector's elements, so the result
+/// has the same range as the source operand.
+void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &sourceRange = argRanges[0];
+ setResultRanges(getResult(), sourceRange);
+}
+
/// Returns true if each element of 'a' is equal to the product of a contiguous
/// sequence of the elements of 'b'. Returns false otherwise.
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 6c9b90e99017508..507c4722cbd1c65 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -36,6 +36,15 @@ func.func @vector_broadcast() -> vector<4x16xindex> {
func.return %2 : vector<4x16xindex>
}
+// CHECK-LABEL: func @vector_shape_cast
+// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
+func.func @vector_shape_cast() -> vector<4x4xindex> {
+ // A shape cast only rearranges elements, so the range [4, 5] is unchanged.
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex>
+ %1 = vector.shape_cast %0 : vector<16xindex> to vector<4x4xindex>
+ %2 = test.reflect_bounds %1 : vector<4x4xindex>
+ func.return %2 : vector<4x4xindex>
+}
+
// CHECK-LABEL: func @vector_extract
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
func.func @vector_extract() -> index {
>From 9dca25d3306df82862cec63274ed07bfb61c226a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 31 Oct 2024 15:58:08 +0100
Subject: [PATCH 4/6] review fixes
---
.../Arith/IR/InferIntRangeInterfaceImpls.cpp | 19 +++++++------------
.../Dialect/Vector/int-range-interface.mlir | 1 -
2 files changed, 7 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 3df483a4d2ddd02..8682294c8a6972b 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -35,26 +35,21 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- if (auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
- const APInt &value = constAttr.getValue();
+ // Scalar integer constant: the range is exactly that value.
+ if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
+ const APInt &value = scalarCstAttr.getValue();
setResultRange(getResult(), ConstantIntRanges::constant(value));
return;
}
- if (auto constAttr =
+ // Dense integer elements constant: the range is the union of the ranges of
+ // its element values.
+ if (auto arrayCstAttr =
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
+ // Vectors cannot be zero-sized, but `arith.constant` may also produce a
+ // zero-element tensor (e.g. `tensor<0xi32>`). It has no element value to
+ // derive a range from, so leave the result range unset instead of
+ // dereferencing an empty optional below.
+ if (arrayCstAttr.empty())
+ return;
+
+ // A splat holds the same value in every element: use it directly instead
+ // of visiting each of the (possibly very many) identical elements.
+ if (arrayCstAttr.isSplat()) {
+ setResultRange(getResult(), ConstantIntRanges::constant(
+ arrayCstAttr.getSplatValue<APInt>()));
+ return;
+ }
+
std::optional<ConstantIntRanges> result;
- for (APInt &&val : constAttr) {
+ for (const APInt &val : arrayCstAttr) {
auto range = ConstantIntRanges::constant(val);
- if (!result) {
- result = range;
- } else {
- result = result->rangeUnion(range);
- }
+ result = (result ? result->rangeUnion(range) : range);
}
- if (result)
- setResultRange(getResult(), *result);
-
+ assert(result && "non-empty elements attribute must produce a range");
+ setResultRange(getResult(), *result);
return;
}
}
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 507c4722cbd1c65..3c5dfc67dea34fc 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -17,7 +17,6 @@ func.func @constant_splat() -> vector<8xi32> {
func.return %1 : vector<8xi32>
}
-
// CHECK-LABEL: func @vector_splat
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_splat() -> vector<4xindex> {
>From f879041f35a67da4d00eae1f6d3934bf639a6f67 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 1 Nov 2024 00:15:07 +0100
Subject: [PATCH 5/6] docs
---
mlir/test/lib/Dialect/Test/TestOps.td | 22 ++++++++++++++++++++++
1 file changed, 22 insertions(+)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e51778c4433889b..cfe19a2fd5c08b4 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2786,6 +2786,16 @@ def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Inde
def TestWithBoundsOp : TEST_Op<"with_bounds",
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
NoMemoryEffect]> {
+ let description = [{
+ Creates a value with specified [min, max] range for integer range analysis.
+
+ Example:
+
+ ```mlir
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+ ```
+ }];
+
let arguments = (ins APIntAttr:$umin,
APIntAttr:$umax,
APIntAttr:$smin,
@@ -2819,6 +2829,18 @@ def TestIncrementOp : TEST_Op<"increment",
def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
AllTypesMatch<["value", "result"]>]> {
+ let description = [{
+ Integer range analysis will update this op to reflect inferred integer range
+ of the input, so it can be checked with FileCheck
+
+ Example:
+
+ ```mlir
+ CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+ %1 = test.reflect_bounds %0 : index
+ ```
+ }];
+
let arguments = (ins InferIntRangeType:$value,
OptionalAttr<APIntAttr>:$umin,
OptionalAttr<APIntAttr>:$umax,
>From ab9970d0ffe2255ad21465673dadd743df4ac378 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 1 Nov 2024 20:04:35 +0100
Subject: [PATCH 6/6] test
---
mlir/test/Dialect/Vector/int-range-interface.mlir | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 3c5dfc67dea34fc..29282423089ba65 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -93,3 +93,14 @@ func.func @vector_insertelement() -> vector<4xindex> {
%3 = test.reflect_bounds %2 : vector<4xindex>
func.return %3 : vector<4xindex>
}
+
+// CHECK-LABEL: func @test_loaded_vector_extract
+// No range is inferred for an element of a loaded vector, so the op below is
+// expected to print without any bounds attributes.
+// CHECK: test.reflect_bounds %{{.*}} : i32
+func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
+ %c0 = arith.constant 0 : index
+ %v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
+ %e = vector.extract %v[0] : i32 from vector<4xi32>
+ %bounds = test.reflect_bounds %e : i32
+ func.return %bounds : i32
+}
More information about the Mlir-commits
mailing list