[Mlir-commits] [mlir] 24bf554 - Add type function for ConstShape op.

Jacques Pienaar llvmlistbot at llvm.org
Mon May 17 11:47:34 PDT 2021


Author: Jacques Pienaar
Date: 2021-05-17T11:47:19-07:00
New Revision: 24bf554b1059d1ee27040ea90fc046d75950e58d

URL: https://github.com/llvm/llvm-project/commit/24bf554b1059d1ee27040ea90fc046d75950e58d
DIFF: https://github.com/llvm/llvm-project/commit/24bf554b1059d1ee27040ea90fc046d75950e58d.diff

LOG: Add type function for ConstShape op.

- Enables inferring the return type for ConstShape, taking the valid return types into account;
- The compatible-return-types function could be reused; that refactoring is left for its next use;

Differential Revision: https://reviews.llvm.org/D102182

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index a4000180f27d1..d415bb8b56225 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -93,7 +93,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
   let verifier = [{ return ::verify(*this); }];
 }
 
-def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
+def Shape_ConstShapeOp : Shape_Op<"const_shape",
+    [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Creates a constant shape or extent tensor";
   let description = [{
     Creates a constant shape or extent tensor. The individual extents are given
@@ -103,7 +104,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
     ```mlir
     %0 = shape.const_shape [] : !shape.shape
     %1 = shape.const_shape [1, 2, 3] : !shape.shape
-    %2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+    %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
     ```
   }];
   let arguments = (ins IndexElementsAttr:$shape);
@@ -114,6 +115,11 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
   let parser = [{ return ::parse$cppClass(parser, result); }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+
+  let extraClassDeclaration = [{
+    // InferTypeOpInterface:
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 def Shape_ConstSizeOp : Shape_Op<"const_size", [

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ac67a62a0aef5..2c38710426696 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -771,6 +772,37 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<TensorCastConstShape>(context);
 }
 
+LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Builder b(context);
+  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
+  if (!shape)
+    return emitOptionalError(location, "missing shape attribute");
+  inferredReturnTypes.assign({RankedTensorType::get(
+      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
+  return success();
+}
+
+bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
+                                                        TypeRange r) {
+  if (l.size() != 1 || r.size() != 1)
+    return false;
+
+  Type lhs = l.front();
+  Type rhs = r.front();
+
+  if (lhs == rhs)
+    return true;
+
+  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+    // Shape type is compatible with all other valid return types.
+    return true;
+
+  return succeeded(verifyCompatibleShapes(lhs, rhs));
+}
+
 //===----------------------------------------------------------------------===//
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d42f0fac4b5c2..c605e25b3873c 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -254,3 +254,13 @@ func @fn(%arg: !shape.shape) -> !shape.witness {
   %0 = shape.cstr_broadcastable %arg : !shape.shape
   return %0 : !shape.witness
 }
+
+// -----
+
+// Test that type inference flags the wrong return type.
+
+func @const_shape() {
+  // expected-error@+1 {{'tensor<3xindex>' are incompatible with return type(s) of operation 'tensor<2xindex>'}}
+  %0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
+  return
+}

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index b9ae301d55799..24e7c2a6a2559 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -36,6 +36,7 @@ func @test_shape_num_elements_unknown() {
 func @const_shape() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
   %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
   return
 }
 


        


More information about the Mlir-commits mailing list