[Mlir-commits] [llvm] [mlir] [mlir][Interfaces][NFC] Add TableGen test op for value bounds tests (PR #88717)
Matthias Springer
llvmlistbot at llvm.org
Mon Apr 15 04:53:55 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/88717
This commit is a code cleanup. It defines the test ops that are used for the `ValueBoundsOpInterface` tests in TableGen, along with proper verifiers.
>From b59af1d57e1a10cb1baee220bb45156ab0372cee Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 15 Apr 2024 11:50:20 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] Add TableGen test op for value bounds
tests
This commit is a code cleanup. It defines the test ops that are used for the `ValueBoundsOpInterface` tests in TableGen, along with proper verifiers.
---
.../value-bounds-op-interface-impl.mlir | 2 +-
.../Affine/value-bounds-reification.mlir | 6 +-
.../value-bounds-op-interface-impl.mlir | 4 +-
.../Dialect/Vector/test-scalable-bounds.mlir | 18 +-
mlir/test/lib/Dialect/Affine/CMakeLists.txt | 1 +
.../Dialect/Affine/TestReifyValueBounds.cpp | 290 +++++++-----------
mlir/test/lib/Dialect/Test/CMakeLists.txt | 1 +
mlir/test/lib/Dialect/Test/TestDialect.cpp | 48 +++
mlir/test/lib/Dialect/Test/TestDialect.h | 1 +
mlir/test/lib/Dialect/Test/TestOps.td | 47 +++
.../mlir/test/BUILD.bazel | 2 +
11 files changed, 219 insertions(+), 201 deletions(-)
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 10da91870f49d9..23c6872dcebe94 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -74,7 +74,7 @@ func.func @composed_affine_apply(%i1 : index) -> (index) {
%i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
%i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
%s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
- %reified = "test.reify_constant_bound"(%s) {type = "EQ"} : (index) -> (index)
+ %reified = "test.reify_bound"(%s) {type = "EQ", constant} : (index) -> (index)
return %reified : index
}
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 909c9098c51607..75622f59af83be 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -47,7 +47,7 @@ func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f
%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%bound) : (index) -> ()
- %bound_const = "test.reify_constant_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+ %bound_const = "test.reify_bound"(%filled) {dim = 1, type = "UB", constant} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%bound_const) : (index) -> ()
}
return
@@ -93,7 +93,7 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
// CHECK: %[[c129:.*]] = arith.constant 129 : index
// CHECK: "test.some_use"(%[[c129]])
- %lb1_ub_const = "test.reify_constant_bound"(%lb1) {type = "UB"} : (index) -> (index)
+ %lb1_ub_const = "test.reify_bound"(%lb1) {type = "UB", constant} : (index) -> (index)
"test.some_use"(%lb1_ub_const) : (index) -> ()
scf.for %iv1 = %lb1 to %ub1 step %c32 {
@@ -116,7 +116,7 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
// CHECK: %[[c32:.*]] = arith.constant 32 : index
// CHECK: "test.some_use"(%[[c32]])
- %matmul_ub_const = "test.reify_constant_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+ %matmul_ub_const = "test.reify_bound"(%matmul) {dim = 1, type = "UB", constant} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%matmul_ub_const) : (index) -> ()
}
}
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 0c90bcdb42028c..0ba9983723a0a1 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -83,7 +83,7 @@ func.func @extract_slice_static(%t: tensor<?xf32>) -> index {
func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> index {
%0 = tensor.extract_slice %t[2][%sz][1] : tensor<?xf32> to tensor<?xf32>
// expected-error @below{{could not reify bound}}
- %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+ %1 = "test.reify_bound"(%0) {dim = 0, constant} : (tensor<?xf32>) -> (index)
return %1 : index
}
@@ -95,7 +95,7 @@ func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> inde
// CHECK: return %[[c5]]
func.func @extract_slice_static_constant(%t: tensor<?xf32>) -> index {
%0 = tensor.extract_slice %t[2][5][1] : tensor<?xf32> to tensor<5xf32>
- %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index)
+ %1 = "test.reify_bound"(%0) {dim = 0, constant} : (tensor<5xf32>) -> (index)
return %1 : index
}
diff --git a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
index 245a6f5c13ac3d..d549c5bd1c3785 100644
--- a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
@@ -26,8 +26,8 @@ func.func @fixed_size_loop_nest() {
%min_i = affine.min #map_dim_i(%i)[%c4_vscale]
scf.for %j = %c0 to %c16 step %c4_vscale {
%min_j = affine.min #map_dim_j(%j)[%c4_vscale]
- %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
- %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
+ %bound_j = "test.reify_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
"test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
}
}
@@ -58,8 +58,8 @@ func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
%min_i = affine.min #map_dynamic_dim(%i)[%c4_vscale, %dim0]
scf.for %j = %c0 to %dim1 step %c4_vscale {
%min_j = affine.min #map_dynamic_dim(%j)[%c4_vscale, %dim1]
- %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
- %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
+ %bound_j = "test.reify_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
"test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
}
}
@@ -80,7 +80,7 @@ func.func @add_to_vscale() {
%vscale = vector.vscale
%c8 = arith.constant 8 : index
%vscale_plus_c8 = arith.addi %vscale, %c8 : index
- %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
+ %bound = "test.reify_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
"test.some_use"(%bound) : (index) -> ()
return
}
@@ -94,7 +94,7 @@ func.func @add_to_vscale() {
// CHECK: "test.some_use"(%[[C2]]) : (index) -> ()
func.func @vscale_fixed_size() {
%vscale = vector.vscale
- %bound = "test.reify_scalable_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2} : (index) -> index
+ %bound = "test.reify_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2, scalable} : (index) -> index
"test.some_use"(%bound) : (index) -> ()
return
}
@@ -107,7 +107,7 @@ func.func @unknown_bound(%a: index) {
%vscale = vector.vscale
%vscale_plus_a = arith.muli %vscale, %a : index
// expected-error @below{{could not reify bound}}
- %bound = "test.reify_scalable_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ %bound = "test.reify_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
"test.some_use"(%bound) : (index) -> ()
return
}
@@ -134,7 +134,7 @@ func.func @duplicate_vscale_values() {
%c2_vscale = arith.muli %vscale_1, %c2 : index
%add = arith.addi %c2_vscale, %c4_vscale : index
- %bound = "test.reify_scalable_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
+ %bound = "test.reify_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
"test.some_use"(%bound) : (index) -> ()
return
}
@@ -154,7 +154,7 @@ func.func @non_scalable_code() {
%c0 = arith.constant 0 : index
scf.for %i = %c0 to %c1024 step %c4 {
%min_i = affine.min #map_dim_i(%i)[%c4]
- %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+ %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
"test.some_use"(%bound_i) : (index) -> ()
}
return
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
index 14960a45d39bab..a8af7285573456 100644
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -30,5 +30,6 @@ add_mlir_library(MLIRAffineTransformsTestPasses
MLIRSupport
MLIRMemRefDialect
MLIRTensorDialect
+ MLIRTestDialect
MLIRVectorUtils
)
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index f38631054fb3c1..d6f8c5dabaa49d 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "TestDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
@@ -57,16 +58,6 @@ struct TestReifyValueBounds
} // namespace
-static FailureOr<BoundType> parseBoundType(const std::string &type) {
- if (type == "EQ")
- return BoundType::EQ;
- if (type == "LB")
- return BoundType::LB;
- if (type == "UB")
- return BoundType::UB;
- return failure();
-}
-
static FailureOr<ValueBoundsConstraintSet::ComparisonOperator>
parseComparisonOperator(const std::string &type) {
if (type == "EQ")
@@ -101,144 +92,89 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
bool reifyToFuncArgs,
bool useArithOps) {
IRRewriter rewriter(funcOp.getContext());
- WalkResult result = funcOp.walk([&](Operation *op) {
- // Look for test.reify_bound ops.
- if (op->getName().getStringRef() == "test.reify_bound" ||
- op->getName().getStringRef() == "test.reify_constant_bound" ||
- op->getName().getStringRef() == "test.reify_scalable_bound") {
- if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
- !op->getResultTypes()[0].isIndex()) {
- op->emitOpError("invalid op");
- return WalkResult::skip();
- }
- Value value = op->getOperand(0);
- if (isa<IndexType>(value.getType()) !=
- !op->hasAttrOfType<IntegerAttr>("dim")) {
- // Op should have "dim" attribute if and only if the operand is an
- // index-typed value.
- op->emitOpError("invalid op");
- return WalkResult::skip();
- }
-
- // Get bound type.
- std::string boundTypeStr = "EQ";
- if (auto boundTypeAttr = op->getAttrOfType<StringAttr>("type"))
- boundTypeStr = boundTypeAttr.str();
- auto boundType = parseBoundType(boundTypeStr);
- if (failed(boundType)) {
- op->emitOpError("invalid op");
- return WalkResult::interrupt();
- }
-
- // Get shape dimension (if any).
- auto dim = value.getType().isIndex()
- ? std::nullopt
- : std::make_optional<int64_t>(
- op->getAttrOfType<IntegerAttr>("dim").getInt());
-
- // Check if a constant was requested.
- bool constant =
- op->getName().getStringRef() == "test.reify_constant_bound";
-
- bool scalable = !constant && op->getName().getStringRef() ==
- "test.reify_scalable_bound";
-
- // Prepare stop condition. By default, reify in terms of the op's
- // operands. No stop condition is used when a constant was requested.
- std::function<bool(Value, std::optional<int64_t>,
- ValueBoundsConstraintSet & cstr)>
- stopCondition = [&](Value v, std::optional<int64_t> d,
- ValueBoundsConstraintSet &cstr) {
- // Reify in terms of SSA values that are different from `value`.
- return v != value;
- };
- if (reifyToFuncArgs) {
- // Reify in terms of function block arguments.
- stopCondition = [](Value v, std::optional<int64_t> d,
- ValueBoundsConstraintSet &cstr) {
- auto bbArg = dyn_cast<BlockArgument>(v);
- if (!bbArg)
- return false;
- return isa<FunctionOpInterface>(
- bbArg.getParentBlock()->getParentOp());
+ WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
+ auto boundType = op.getBoundType();
+ Value value = op.getVar();
+ std::optional<int64_t> dim = op.getDim();
+ bool constant = op.getConstant();
+ bool scalable = op.getScalable();
+
+ // Prepare stop condition. By default, reify in terms of the op's
+ // operands. No stop condition is used when a constant was requested.
+ std::function<bool(Value, std::optional<int64_t>,
+ ValueBoundsConstraintSet & cstr)>
+ stopCondition = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
+ // Reify in terms of SSA values that are different from `value`.
+ return v != value;
};
- }
+ if (reifyToFuncArgs) {
+ // Reify in terms of function block arguments.
+ stopCondition = [](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
+ auto bbArg = dyn_cast<BlockArgument>(v);
+ if (!bbArg)
+ return false;
+ return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp());
+ };
+ }
- // Reify value bound
- rewriter.setInsertionPointAfter(op);
- FailureOr<OpFoldResult> reified = failure();
- if (constant) {
- auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
- *boundType, value, dim, /*stopCondition=*/nullptr);
- if (succeeded(reifiedConst))
- reified =
- FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
- } else if (scalable) {
- unsigned vscaleMin = 0;
- unsigned vscaleMax = 0;
- if (auto attr = "vscale_min"; op->hasAttrOfType<IntegerAttr>(attr)) {
- vscaleMin = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
- } else {
- op->emitOpError("expected `vscale_min` to be provided");
- return WalkResult::skip();
+ // Reify value bound
+ rewriter.setInsertionPointAfter(op);
+ FailureOr<OpFoldResult> reified = failure();
+ if (constant) {
+ auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
+ boundType, value, dim, /*stopCondition=*/nullptr);
+ if (succeeded(reifiedConst))
+ reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
+ } else if (scalable) {
+ auto loc = op->getLoc();
+ auto reifiedScalable =
+ vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+ value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType);
+ if (succeeded(reifiedScalable)) {
+ SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
+ if (reifiedScalable->map.getNumInputs() == 1) {
+ // The only possible input to the bound is vscale.
+ vscaleOperand.push_back(std::make_pair(
+ rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
}
- if (auto attr = "vscale_max"; op->hasAttrOfType<IntegerAttr>(attr)) {
- vscaleMax = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
+ reified = affine::materializeComputedBound(
+ rewriter, loc, reifiedScalable->map, vscaleOperand);
+ }
+ } else {
+ if (dim) {
+ if (useArithOps) {
+ reified = arith::reifyShapedValueDimBound(
+ rewriter, op->getLoc(), boundType, value, *dim, stopCondition);
} else {
- op->emitOpError("expected `vscale_max` to be provided");
- return WalkResult::skip();
- }
-
- auto loc = op->getLoc();
- auto reifiedScalable =
- vector::ScalableValueBoundsConstraintSet::computeScalableBound(
- value, dim, vscaleMin, vscaleMax, *boundType);
- if (succeeded(reifiedScalable)) {
- SmallVector<std::pair<Value, std::optional<int64_t>>, 1>
- vscaleOperand;
- if (reifiedScalable->map.getNumInputs() == 1) {
- // The only possible input to the bound is vscale.
- vscaleOperand.push_back(std::make_pair(
- rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
- }
- reified = affine::materializeComputedBound(
- rewriter, loc, reifiedScalable->map, vscaleOperand);
+ reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType,
+ value, *dim, stopCondition);
}
} else {
- if (dim) {
- if (useArithOps) {
- reified = arith::reifyShapedValueDimBound(
- rewriter, op->getLoc(), *boundType, value, *dim, stopCondition);
- } else {
- reified = reifyShapedValueDimBound(
- rewriter, op->getLoc(), *boundType, value, *dim, stopCondition);
- }
+ if (useArithOps) {
+ reified = arith::reifyIndexValueBound(
+ rewriter, op->getLoc(), boundType, value, stopCondition);
} else {
- if (useArithOps) {
- reified = arith::reifyIndexValueBound(
- rewriter, op->getLoc(), *boundType, value, stopCondition);
- } else {
- reified = reifyIndexValueBound(rewriter, op->getLoc(), *boundType,
- value, stopCondition);
- }
+ reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType,
+ value, stopCondition);
}
}
- if (failed(reified)) {
- op->emitOpError("could not reify bound");
- return WalkResult::interrupt();
- }
+ }
+ if (failed(reified)) {
+ op->emitOpError("could not reify bound");
+ return WalkResult::interrupt();
+ }
- // Replace the op with the reified bound.
- if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
- rewriter.replaceOp(op, val);
- return WalkResult::skip();
- }
- Value constOp = rewriter.create<arith::ConstantIndexOp>(
- op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
- rewriter.replaceOp(op, constOp);
+ // Replace the op with the reified bound.
+ if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
+ rewriter.replaceOp(op, val);
return WalkResult::skip();
}
- return WalkResult::advance();
+ Value constOp = rewriter.create<arith::ConstantIndexOp>(
+ op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
+ rewriter.replaceOp(op, constOp);
+ return WalkResult::skip();
});
return failure(result.wasInterrupted());
}
@@ -246,60 +182,42 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
/// Look for "test.compare" ops and emit errors/remarks.
static LogicalResult testEquality(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
- WalkResult result = funcOp.walk([&](Operation *op) {
- // Look for test.compare ops.
- if (op->getName().getStringRef() == "test.compare") {
- if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() ||
- !op->getOperand(1).getType().isIndex()) {
- op->emitOpError("invalid op");
- return WalkResult::skip();
- }
-
- // Get comparison operator.
- std::string cmpStr = "EQ";
- if (auto cmpAttr = op->getAttrOfType<StringAttr>("cmp"))
- cmpStr = cmpAttr.str();
- auto cmpType = parseComparisonOperator(cmpStr);
- if (failed(cmpType)) {
- op->emitOpError("invalid comparison operator");
+ WalkResult result = funcOp.walk([&](test::CompareOp op) {
+ auto cmpType = op.getComparisonOperator();
+ if (op.getCompose()) {
+ if (cmpType != ValueBoundsConstraintSet::EQ) {
+ op->emitOpError(
+ "comparison operator must be EQ when 'composed' is specified");
return WalkResult::interrupt();
}
-
- if (op->hasAttr("compose")) {
- if (cmpType != ValueBoundsConstraintSet::EQ) {
- op->emitOpError(
- "comparison operator must be EQ when 'composed' is specified");
- return WalkResult::interrupt();
- }
- FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
- op->getOperand(0), op->getOperand(1));
- if (failed(delta)) {
- op->emitError("could not determine equality");
- } else if (*delta == 0) {
- op->emitRemark("equal");
- } else {
- op->emitRemark("different");
- }
- return WalkResult::advance();
- }
-
- auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
- return ValueBoundsConstraintSet::compare(
- /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp,
- /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt);
- };
- if (compare(*cmpType)) {
- op->emitRemark("true");
- } else if (*cmpType != ValueBoundsConstraintSet::EQ &&
- compare(invertComparisonOperator(*cmpType))) {
- op->emitRemark("false");
- } else if (*cmpType == ValueBoundsConstraintSet::EQ &&
- (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
- compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
- op->emitRemark("false");
+ FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
+ op->getOperand(0), op->getOperand(1));
+ if (failed(delta)) {
+ op->emitError("could not determine equality");
+ } else if (*delta == 0) {
+ op->emitRemark("equal");
} else {
- op->emitError("unknown");
+ op->emitRemark("different");
}
+ return WalkResult::advance();
+ }
+
+ auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
+ return ValueBoundsConstraintSet::compare(
+ /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp,
+ /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt);
+ };
+ if (compare(cmpType)) {
+ op->emitRemark("true");
+ } else if (cmpType != ValueBoundsConstraintSet::EQ &&
+ compare(invertComparisonOperator(cmpType))) {
+ op->emitRemark("false");
+ } else if (cmpType == ValueBoundsConstraintSet::EQ &&
+ (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
+ compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
+ op->emitRemark("false");
+ } else {
+ op->emitError("unknown");
}
return WalkResult::advance();
});
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 47ddcf6524748c..d246c0492a3bd5 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -85,6 +85,7 @@ add_mlir_library(MLIRTestDialect
MLIRTensorDialect
MLIRTransformUtils
MLIRTransforms
+ MLIRValueBoundsOpInterface
)
add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 74378633032ce7..25c5190ca0ef3a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -516,6 +516,54 @@ static void printOptionalCustomParser(AsmPrinter &p, Operation *,
p.printAttribute(result);
}
+//===----------------------------------------------------------------------===//
+// ReifyBoundOp
+//===----------------------------------------------------------------------===//
+
+::mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
+  using ::mlir::presburger::BoundType;
+  ::llvm::StringRef kind = getType(); // Textual `type` attribute.
+  if (kind == "LB" || kind == "UB")
+    return kind == "LB" ? BoundType::LB : BoundType::UB;
+  if (kind == "EQ")
+    return BoundType::EQ;
+  llvm_unreachable("invalid bound type");
+}
+
+LogicalResult ReifyBoundOp::verify() {
+  if (getType() != "EQ" && getType() != "LB" && getType() != "UB")
+    return emitOpError("expected 'type' to be one of EQ, LB, UB");
+  bool isShaped = isa<ShapedType>(getVar().getType());
+  if (!isShaped && !getVar().getType().isIndex())
+    return emitOpError("expected index-typed variable or shape type variable");
+  if (isShaped && !getDim().has_value())
+    return emitOpError("expected 'dim' attribute for shaped type variable");
+  if (!isShaped && getDim().has_value())
+    return emitOpError("unexpected 'dim' attribute for index variable");
+  if (getConstant() && getScalable())
+    return emitOpError("'scalable' and 'constant' are mutually exclusive");
+  if (getScalable() != getVscaleMin().has_value())
+    return emitOpError("expected 'vscale_min' if and only if 'scalable'");
+  if (getScalable() != getVscaleMax().has_value())
+    return emitOpError("expected 'vscale_max' if and only if 'scalable'");
+  return success();
+}
+
+::mlir::ValueBoundsConstraintSet::ComparisonOperator
+CompareOp::getComparisonOperator() {
+  using Cmp = ValueBoundsConstraintSet::ComparisonOperator;
+  // Decode the textual `cmp` attribute into a comparison operator.
+  ::llvm::StringRef pred = getCmp();
+  if (pred == "EQ")
+    return Cmp::EQ;
+  if (pred == "LT" || pred == "LE")
+    return pred == "LT" ? Cmp::LT : Cmp::LE;
+  if (pred == "GT" || pred == "GE")
+    return pred == "GT" ? Cmp::GT : Cmp::GE;
+  // NOTE(review): CompareOp declares no verifier, so other strings end here.
+  llvm_unreachable("invalid comparison operator");
+}
+
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 4ba28c47ed1c33..d5b2fbeafc4104 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -41,6 +41,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include <memory>
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c88f85b8b6b361..0ab2427038add3 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2184,6 +2184,53 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
let results = (outs AnyRankedOrUnrankedMemRef:$result);
}
+//===----------------------------------------------------------------------===//
+// Test ValueBoundsOpInterface
+//===----------------------------------------------------------------------===//
+
+def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
+  let description = [{
+    Reify a bound for an index-typed `var`, or for dimension `dim` of a shaped
+    `var`. `type` selects an "LB", "EQ" (default) or "UB" bound. If `constant`
+    is set, a constant bound is computed. If `scalable` is set, a bound in terms
+    of "vector.vscale" is computed and `vscale_min`/`vscale_max` are required.
+  }];
+
+  let arguments = (ins AnyType:$var,
+                       OptionalAttr<I64Attr>:$dim,
+                       DefaultValuedAttr<StrAttr, "\"EQ\"">:$type,
+                       UnitAttr:$constant,
+                       UnitAttr:$scalable,
+                       OptionalAttr<I64Attr>:$vscale_min,
+                       OptionalAttr<I64Attr>:$vscale_max);
+  let results = (outs Index:$result);
+
+  let extraClassDeclaration = [{
+    ::mlir::presburger::BoundType getBoundType();
+  }];
+
+  let hasVerifier = 1;
+}
+
+def CompareOp : TEST_Op<"compare"> {
+  let description = [{
+    Compare `lhs` and `rhs` using `cmp` ("EQ" by default; also "LT", "LE", "GT"
+    and "GE") and emit a remark stating whether the relation or its opposite
+    was proven. With `compose`, `cmp` must be "EQ" (remark: equal/different).
+  }];
+
+  let arguments = (ins Index:$lhs,
+                       Index:$rhs,
+                       DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp,
+                       UnitAttr:$compose);
+  let results = (outs);
+
+  let extraClassDeclaration = [{
+    ::mlir::ValueBoundsConstraintSet::ComparisonOperator
+        getComparisonOperator();
+  }];
+}
+
//===----------------------------------------------------------------------===//
// Test RegionBranchOpInterface
//===----------------------------------------------------------------------===//
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 684b59e7f62f65..dc5f4047c286db 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -421,6 +421,7 @@ cc_library(
"//mlir:TranslateLib",
"//mlir:TransformUtils",
"//mlir:Transforms",
+ "//mlir:ValueBoundsOpInterface",
"//mlir:ViewLikeInterface",
],
)
@@ -644,6 +645,7 @@ cc_library(
"lib/Dialect/Affine/*.cpp",
]),
deps = [
+ ":TestDialect",
"//llvm:Support",
"//mlir:AffineAnalysis",
"//mlir:AffineDialect",
More information about the Mlir-commits
mailing list