[Mlir-commits] [mlir] 5142448 - [MLIR][Shape] Refactor verification
Jacques Pienaar
llvmlistbot at llvm.org
Sat Jul 25 14:58:32 PDT 2020
Author: Jacques Pienaar
Date: 2020-07-25T14:55:19-07:00
New Revision: 5142448a5e2aeeffefb3aabdb48f19033025bc09
URL: https://github.com/llvm/llvm-project/commit/5142448a5e2aeeffefb3aabdb48f19033025bc09
DIFF: https://github.com/llvm/llvm-project/commit/5142448a5e2aeeffefb3aabdb48f19033025bc09.diff
LOG: [MLIR][Shape] Refactor verification
Based on https://reviews.llvm.org/D84439 but less restrictive: the stricter
checks proposed there would not allow shape_of to produce a ranked output and
would not allow for iterative refinement here. We can consider making the
verification more restrictive later.
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 797dc0bc0cb6..8c32faee55f9 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -207,7 +207,7 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
let hasFolder = 1;
let hasCanonicalizer = 1;
- let verifier = [{ return ::verify(*this); }];
+ let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@@ -252,7 +252,7 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
}];
let hasFolder = 1;
- let verifier = [{ return ::verify(*this); }];
+ let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}
def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
@@ -325,7 +325,7 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
$lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($result) attr-dict
}];
- let verifier = [{ return ::verify(*this); }];
+ let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
let hasFolder = 1;
}
@@ -412,7 +412,7 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
- let verifier = [{ return ::verify(*this); }];
+ let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index d2b0dbdedb05..104ab46c5581 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -28,13 +28,37 @@ static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}
-static bool isErrorPropagationPossible(ArrayRef<Type> operandTypes) {
+static bool isErrorPropagationPossible(TypeRange operandTypes) {
for (Type ty : operandTypes)
if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
return true;
return false;
}
+/// Shared verifier for ops with a single `size`-or-`index` result (used by
+/// shape.rank, shape.get_extent and shape.mul via ShapeOps.td).
+///
+/// If at least one operand can hold an error value, the result must be of
+/// type `size` so that the error can be propagated. An operand can hold an
+/// error value if isErrorPropagationPossible reports it as `size`, `shape` or
+/// `value_shape`. Otherwise the result type is not constrained by this
+/// function.
+static LogicalResult verifySizeOrIndexOp(Operation *op) {
+ // Callers must pass an op with exactly one result; front() below relies on it.
+ assert(op != nullptr && op->getNumResults() == 1);
+ Type resultTy = op->getResultTypes().front();
+ if (isErrorPropagationPossible(op->getOperandTypes())) {
+ if (!resultTy.isa<SizeType>())
+ return op->emitOpError()
+ << "if at least one of the operands can hold error values then "
+ "the result must be of type `size` to propagate them";
+ }
+ // Intentionally no else branch: when no operand can hold an error, both an
+ // `index` and a `size` result are accepted here.
+ return success();
+}
+
+/// Shared verifier for ops with a single shape-like result, i.e. a `shape` or
+/// an extent tensor (used by shape.shape_of via ShapeOps.td).
+///
+/// If at least one operand can hold an error value, the result must be of
+/// type `shape` so that the error can be propagated. An operand can hold an
+/// error value if isErrorPropagationPossible reports it as `size`, `shape` or
+/// `value_shape`. Otherwise the result type is not constrained by this
+/// function.
+static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
+ // Callers must pass an op with exactly one result; front() below relies on it.
+ assert(op != nullptr && op->getNumResults() == 1);
+ Type resultTy = op->getResultTypes().front();
+ if (isErrorPropagationPossible(op->getOperandTypes())) {
+ if (!resultTy.isa<ShapeType>())
+ return op->emitOpError()
+ << "if at least one of the operands can hold error values then "
+ "the result must be of type `shape` to propagate them";
+ }
+ // Intentionally no else branch: for error-free operands any result type is
+ // left to the op's declared type constraints. This permits, e.g., a ranked
+ // extent tensor result and later refinement.
+ return success();
+}
+
ShapeDialect::ShapeDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
@@ -542,23 +566,6 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
// GetExtentOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(GetExtentOp op) {
- Type shapeTy = op.shape().getType();
- Type dimTy = op.dim().getType();
- Type extentTy = op.extent().getType();
- if (isErrorPropagationPossible({shapeTy, dimTy})) {
- if (!extentTy.isa<SizeType>())
- op.emitError()
- << "if at least one of the operands can hold error values then the "
- "result must be of type `size` to propagate them";
- } else {
- if (extentTy.isa<SizeType>())
- op.emitError() << "if none of the operands can hold error values then "
- "the result must be of type `index`";
- }
- return success();
-}
-
Optional<int64_t> GetExtentOp::getConstantDim() {
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
return constSizeOp.value().getLimitedValue();
@@ -597,15 +604,6 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
// RankOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(shape::RankOp op) {
- if (op.shape().getType().isa<ShapeType>() &&
- !op.rank().getType().isa<SizeType>())
- return op.emitOpError()
- << "if operand is of type `shape` then the result must be of type "
- "`size` to propagate potential errors";
- return success();
-}
-
OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!shape)
@@ -680,21 +678,6 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
// MulOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(MulOp op) {
- Type resultTy = op.result().getType();
- if (isErrorPropagationPossible({op.lhs().getType(), op.rhs().getType()})) {
- if (!resultTy.isa<SizeType>())
- return op.emitOpError()
- << "if at least one of the operands can hold error values then "
- "the result must be of type `size` to propagate them";
- } else {
- if (resultTy.isa<SizeType>())
- return op.emitError() << "if none of the operands can hold error values "
- "then the result must be of type `index`";
- }
- return success();
-}
-
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
if (!lhs)
@@ -719,21 +702,6 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
return builder.getIndexTensorAttr(type.getShape());
}
-static LogicalResult verify(ShapeOfOp op) {
- Type resultTy = op.result().getType();
- if (isErrorPropagationPossible(op.arg().getType())) {
- if (!resultTy.isa<ShapeType>())
- return op.emitOpError()
- << "if operand is of type `value_shape` then the result must be "
- "of type `shape` to propagate potential error shapes";
- } else {
- if (resultTy != getExtentTensorType(op.getContext()))
- return op.emitOpError() << "if operand is a shaped type then the result "
- "must be an extent tensor";
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index b4900e491fb8..20f4e877a2a9 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -90,39 +90,21 @@ func @assuming_all_op_too_few_operands() {
func @shape_of(%value_arg : !shape.value_shape,
%shaped_arg : tensor<?x3x4xf32>) {
- // expected-error at +1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
+ // expected-error at +1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
%0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
return
}
// -----
-func @shape_of(%value_arg : !shape.value_shape,
- %shaped_arg : tensor<?x3x4xf32>) {
- // expected-error at +1 {{if operand is a shaped type then the result must be an extent tensor}}
- %1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
- return
-}
-
-// -----
-
func @rank(%arg : !shape.shape) {
- // expected-error at +1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
+ // expected-error at +1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%0 = shape.rank %arg : !shape.shape -> index
return
}
// -----
-func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
- %c0 = constant 0 : index
- // expected-error at +1 {{if none of the operands can hold error values then the result must be of type `index`}}
- %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
- return %result : !shape.size
-}
-
-// -----
-
func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
%c0 = shape.const_size 0
// expected-error at +1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
@@ -132,14 +114,6 @@ func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
// -----
-func @mul_error_free(%arg : index) -> !shape.size {
- // expected-error at +1 {{if none of the operands can hold error values then the result must be of type `index`}}
- %result = shape.mul %arg, %arg : index, index -> !shape.size
- return %result : !shape.size
-}
-
-// -----
-
func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
// expected-error at +1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.mul %lhs, %rhs : !shape.size, index -> index
More information about the Mlir-commits
mailing list