[Mlir-commits] [mlir] 24bf554 - Add type function for ConstShape op.
Jacques Pienaar
llvmlistbot at llvm.org
Mon May 17 11:47:34 PDT 2021
Author: Jacques Pienaar
Date: 2021-05-17T11:47:19-07:00
New Revision: 24bf554b1059d1ee27040ea90fc046d75950e58d
URL: https://github.com/llvm/llvm-project/commit/24bf554b1059d1ee27040ea90fc046d75950e58d
DIFF: https://github.com/llvm/llvm-project/commit/24bf554b1059d1ee27040ea90fc046d75950e58d.diff
LOG: Add type function for ConstShape op.
- Enables inferring the return type of ConstShape, taking the valid return types into account;
- The compatible-return-type function could be reused elsewhere; that refactoring is left for when it is next needed;
Differential Revision: https://reviews.llvm.org/D102182
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index a4000180f27d1..d415bb8b56225 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -93,7 +93,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
let verifier = [{ return ::verify(*this); }];
}
-def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
+def Shape_ConstShapeOp : Shape_Op<"const_shape",
+ [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Creates a constant shape or extent tensor";
let description = [{
Creates a constant shape or extent tensor. The individual extents are given
@@ -103,7 +104,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
```mlir
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
- %2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+ %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
```
}];
let arguments = (ins IndexElementsAttr:$shape);
@@ -114,6 +115,11 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
+
+ let extraClassDeclaration = [{
+ // InferTypeOpInterface:
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
def Shape_ConstSizeOp : Shape_Op<"const_size", [
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ac67a62a0aef5..2c38710426696 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -771,6 +772,37 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<TensorCastConstShape>(context);
}
+/// Infers the single result type of `shape.const_shape` from its `shape`
+/// attribute: a statically sized 1-D tensor of `index` whose length equals the
+/// number of extents stored in the attribute (e.g. `[4, 5, 6]` infers
+/// `tensor<3xindex>`).
+///
+/// Fails, reporting "missing shape attribute" through emitOptionalError, when
+/// `shape` cannot be read from `attributes` as a DenseIntElementsAttr.
+LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  DenseIntElementsAttr extents =
+      attributes.getAs<DenseIntElementsAttr>("shape");
+  if (!extents)
+    return emitOptionalError(location, "missing shape attribute");
+
+  // One dimension, sized by the number of constant extents.
+  int64_t numExtents = static_cast<int64_t>(extents.size());
+  Type extentTensorType =
+      RankedTensorType::get({numExtents}, IndexType::get(context));
+  inferredReturnTypes.assign(/*NumElts=*/1, extentTensorType);
+  return success();
+}
+
+/// Returns whether two candidate result-type lists for `shape.const_shape` are
+/// compatible. Each list must contain exactly one type.
+///
+/// - Identical types are compatible.
+/// - `!shape.shape` is treated as compatible with every other valid result type.
+/// - Otherwise both types must have the same element type and compatible
+///   shapes, so `tensor<?xindex>` matches `tensor<3xindex>` while
+///   `tensor<2xindex>` does not.
+bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
+                                                        TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+
+  Type lhs = l.front();
+  Type rhs = r.front();
+
+  if (lhs == rhs)
+    return true;
+
+  // Shape type is compatible with all other valid return types.
+  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+    return true;
+
+  // verifyCompatibleShapes compares shapes only, not element types; check the
+  // element types explicitly so e.g. `tensor<3xi32>` is not accepted where
+  // `tensor<3xindex>` is inferred.
+  if (getElementTypeOrSelf(lhs) != getElementTypeOrSelf(rhs))
+    return false;
+
+  return succeeded(verifyCompatibleShapes(lhs, rhs));
+}
+
//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d42f0fac4b5c2..c605e25b3873c 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -254,3 +254,13 @@ func @fn(%arg: !shape.shape) -> !shape.witness {
%0 = shape.cstr_broadcastable %arg : !shape.shape
return %0 : !shape.witness
}
+
+// -----
+
+// Test that type inference flags the wrong return type.
+
+func @const_shape() {
+  // expected-error@+1 {{'tensor<3xindex>' are incompatible with return type(s) of operation 'tensor<2xindex>'}}
+  %0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
+  return
+}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index b9ae301d55799..24e7c2a6a2559 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -36,6 +36,7 @@ func @test_shape_num_elements_unknown() {
func @const_shape() {
+  // Exercises each result type form: !shape.shape, a dynamically sized extent
+  // tensor, and a statically sized extent tensor.
%0 = shape.const_shape [1, 2, 3] : !shape.shape
%1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+ %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
return
}
More information about the Mlir-commits
mailing list