[Mlir-commits] [mlir] 14d3cef - [MLIR][Shape] Generalize `shape.const_shape` to extent tensors

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 01:07:54 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T08:06:24Z
New Revision: 14d3cef01264c5575ada3ed7619e3e5b582ecbe4

URL: https://github.com/llvm/llvm-project/commit/14d3cef01264c5575ada3ed7619e3e5b582ecbe4
DIFF: https://github.com/llvm/llvm-project/commit/14d3cef01264c5575ada3ed7619e3e5b582ecbe4.diff

LOG: [MLIR][Shape] Generalize `shape.const_shape` to extent tensors

The operation `shape.const_shape` was used for constants of type shape only.
We can now also use it to create constant extent tensors.

Differential Revision: https://reviews.llvm.org/D84157

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 703353c35f5a..1bdfd9a071e3 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -96,18 +96,20 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
 }
 
 def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
-  let summary = "Creates a constant of !shape.shape type";
+  let summary = "Creates a constant shape or extent tensor";
   let description = [{
-    Creates a !shape.shape with rank given by the length of `shape` and with
-    dimension sizes given by the values of `shape`.
+    Creates a constant shape or extent tensor. The individual extents are given
+    as the `shape` attribute. The number of these values equals the shape's
+    rank.
 
     ```mlir
-    %0 = shape.const_shape []
-    %1 = shape.const_shape [1, 2, 3]
+    %0 = shape.const_shape [] : !shape.shape
+    %1 = shape.const_shape [1, 2, 3] : !shape.shape
+    %2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
     ```
   }];
   let arguments = (ins IndexElementsAttr:$shape);
-  let results = (outs Shape_ShapeType:$result);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
   // TODO: Move this to main so that all shape ops implement these.
   let printer = [{ return ::print(p, *this); }];

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 92392b069a04..42b8b34c7e09 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -23,6 +23,11 @@ namespace {
 #include "ShapeCanonicalization.inc"
 }
 
+static RankedTensorType getExtentTensorType(OpBuilder &builder) {
+  return RankedTensorType::get({ShapedType::kDynamicSize},
+                               builder.getIndexType());
+}
+
 ShapeDialect::ShapeDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
@@ -40,12 +45,12 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
-  if (auto shapeType = type.dyn_cast<ShapeType>())
+  if (type.isa<ShapeType>() || type == getExtentTensorType(builder))
     return builder.create<ConstShapeOp>(loc, type,
                                         value.cast<DenseIntElementsAttr>());
-  if (auto sizeType = type.dyn_cast<SizeType>())
+  if (type.isa<SizeType>())
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
-  if (auto witnessType = type.dyn_cast<WitnessType>())
+  if (type.isa<WitnessType>())
     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
   return nullptr;
 }
@@ -290,7 +295,8 @@ static void print(OpAsmPrinter &p, ConstShapeOp &op) {
   p << "[";
   interleaveComma(op.shape().getValues<int64_t>(), p,
                   [&](int64_t i) { p << i; });
-  p << "]";
+  p << "] : ";
+  p.printType(op.getType());
 }
 
 static ParseResult parseConstShapeOp(OpAsmParser &parser,
@@ -316,8 +322,10 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
   }
   Builder &builder = parser.getBuilder();
   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
-
-  result.types.push_back(ShapeType::get(builder.getContext()));
+  Type resultTy;
+  if (parser.parseColonType(resultTy))
+    return failure();
+  result.types.push_back(resultTy);
   return success();
 }
 

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 80b7cb9ddb94..20f21bbc877e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func @f
 func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
-  // CHECK: shape.const_shape [2, 3, 4]
+  // CHECK: shape.const_shape [2, 3, 4] : !shape.shape
   %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
   return %0 : !shape.shape
 }
@@ -12,10 +12,10 @@ func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> (!shape.shape, !shape.shape) {
-  // CHECK: shape.const_shape [2, 3]
-  // CHECK: shape.const_shape [4, 5]
+  // CHECK: shape.const_shape [2, 3] : !shape.shape
+  // CHECK: shape.const_shape [4, 5] : !shape.shape
   %c2 = constant 2 : i32
-  %0 = shape.const_shape [2, 3, 4, 5]
+  %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
   %head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
   return %head, %tail : !shape.shape, !shape.shape
 
@@ -26,10 +26,10 @@ func @f() -> (!shape.shape, !shape.shape) {
 // Negative split point.
 // CHECK-LABEL: func @f
 func @f() -> (!shape.shape, !shape.shape) {
-  // CHECK: shape.const_shape [2, 3, 4]
-  // CHECK: shape.const_shape [5]
+  // CHECK: shape.const_shape [2, 3, 4] : !shape.shape
+  // CHECK: shape.const_shape [5] : !shape.shape
   %c-1 = constant -1 : i32
-  %0 = shape.const_shape [2, 3, 4, 5]
+  %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
   %head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
   return %head, %tail : !shape.shape, !shape.shape
 }
@@ -41,7 +41,7 @@ func @f() -> (!shape.shape, !shape.shape) {
 func @f() -> (!shape.shape, !shape.shape) {
   // CHECK: shape.split_at
   %c5 = constant 5 : i32
-  %0 = shape.const_shape [2, 3, 4, 5]
+  %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
   %head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
   return %head, %tail : !shape.shape, !shape.shape
 }
@@ -51,9 +51,9 @@ func @f() -> (!shape.shape, !shape.shape) {
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
-  // CHECK: shape.const_shape [7, 2]
-  %0 = shape.const_shape [1, 2]
-  %1 = shape.const_shape [7, 1]
+  // CHECK: shape.const_shape [7, 2] : !shape.shape
+  %0 = shape.const_shape [1, 2] : !shape.shape
+  %1 = shape.const_shape [7, 1] : !shape.shape
   %2 = shape.broadcast %0, %1
   return %2 : !shape.shape
 }
@@ -64,7 +64,7 @@ func @f() -> !shape.shape {
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
-  %0 = shape.const_shape []
+  %0 = shape.const_shape [] : !shape.shape
   %1 = shape.broadcast %arg0, %0
   return %1 : !shape.shape
 }
@@ -75,7 +75,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
-  %0 = shape.const_shape []
+  %0 = shape.const_shape [] : !shape.shape
   %1 = shape.broadcast %0, %arg0
   return %1 : !shape.shape
 }
@@ -85,10 +85,10 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
 // Lhs is a scalar and rhs is constant.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
-  // CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3]
+  // CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3] : !shape.shape
   // CHECK: return %[[CST]]
-  %0 = shape.const_shape []
-  %1 = shape.const_shape [1, 2, 3]
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = shape.const_shape [1, 2, 3] : !shape.shape
   %2 = shape.broadcast %0, %1
   return %2 : !shape.shape
 }
@@ -99,8 +99,8 @@ func @f() -> !shape.shape {
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
   // CHECK: shape.broadcast
-  %0 = shape.const_shape [2]
-  %1 = shape.const_shape [7]
+  %0 = shape.const_shape [2] : !shape.shape
+  %1 = shape.const_shape [7] : !shape.shape
   %2 = shape.broadcast %0, %1
   return %2 : !shape.shape
 }
@@ -110,9 +110,9 @@ func @f() -> !shape.shape {
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
-  // CHECK: shape.const_shape [0, 1, 2, 3]
-  %lhs = shape.const_shape [0, 1]
-  %rhs = shape.const_shape [2, 3]
+  // CHECK: shape.const_shape [0, 1, 2, 3] : !shape.shape
+  %lhs = shape.const_shape [0, 1] : !shape.shape
+  %rhs = shape.const_shape [2, 3] : !shape.shape
   %0 = shape.concat %lhs, %rhs
   return %0 : !shape.shape
 }
@@ -123,7 +123,7 @@ func @f() -> !shape.shape {
 // CHECK-LABEL: func @f
 func @f() -> tensor<2xindex> {
   // CHECK: constant dense<[0, 1]> : tensor<2xindex>
-  %cs = shape.const_shape [0, 1]
+  %cs = shape.const_shape [0, 1] : !shape.shape
   %0 = shape.to_extent_tensor %cs : tensor<2xindex>
   return %0 : tensor<2xindex>
 }
@@ -133,7 +133,7 @@ func @f() -> tensor<2xindex> {
 // Basic case.
 // CHECK-LABEL: func @f()
 func @f() -> !shape.shape {
-  // CHECK: shape.const_shape [3, 5, 11]
+  // CHECK: shape.const_shape [3, 5, 11] : !shape.shape
   %e0 = constant 3 : index
   %e1 = constant 5 : index
   %e2 = constant 11 : index
@@ -215,7 +215,7 @@ func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
 // CHECK-LABEL: func @num_elements
 func @num_elements() -> !shape.size {
   // CHECK-NOT: shape.const_shape
-  %shape = shape.const_shape [4, 5, 6]
+  %shape = shape.const_shape [4, 5, 6] : !shape.shape
   // CHECK-NOT: shape.num_elements
   %num_elements = shape.num_elements %shape
   // CHECK: %[[NUM:.*]] = shape.const_size 120
@@ -239,7 +239,7 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
 // CHECK-LABEL: func @basic
 func @basic() -> !shape.size {
   // CHECK: shape.const_size 2
-  %0 = shape.const_shape [0, 1, 2]
+  %0 = shape.const_shape [0, 1, 2] : !shape.shape
   %c2 = shape.const_size 2
   %1 = shape.get_extent %0, %c2
   return %1 : !shape.size
@@ -252,7 +252,7 @@ func @basic() -> !shape.size {
 func @out_of_bounds() -> !shape.size {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2]
+  %0 = shape.const_shape [0, 1, 2] : !shape.shape
   %c3 = shape.const_size 3
   %1 = shape.get_extent %0, %c3
   return %1 : !shape.size
@@ -289,9 +289,9 @@ func @f() {
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %cs0 = shape.const_shape [0, 1]
-  %cs1 = shape.const_shape [0, 1]
-  %cs2 = shape.const_shape [0, 1]
+  %cs0 = shape.const_shape [0, 1] : !shape.shape
+  %cs1 = shape.const_shape [0, 1] : !shape.shape
+  %cs2 = shape.const_shape [0, 1] : !shape.shape
   %0 = shape.cstr_eq %cs0, %cs1, %cs2
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
@@ -306,8 +306,8 @@ func @f() {
   // CHECK-NEXT: shape.cstr_eq
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %cs0 = shape.const_shape [0, 1]
-  %cs1 = shape.const_shape [3, 1]
+  %cs0 = shape.const_shape [0, 1] : !shape.shape
+  %cs1 = shape.const_shape [3, 1] : !shape.shape
   %0 = shape.cstr_eq %cs0, %cs1
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
@@ -367,7 +367,7 @@ func @f() {
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
   // CHECK-NEXT: return %[[CS]]
-  %0 = shape.const_shape [2, 3, 4]
+  %0 = shape.const_shape [2, 3, 4] : !shape.shape
   %1 = shape.any %0, %arg0
   return %1 : !shape.shape
 }
@@ -429,8 +429,8 @@ func @f() {
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %cs0 = shape.const_shape [3, 1]
-  %cs1 = shape.const_shape [1, 5]
+  %cs0 = shape.const_shape [3, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 5] : !shape.shape
   %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
@@ -445,8 +445,8 @@ func @static_non_broadcastable() {
   // CHECK-NEXT: shape.cstr_broadcastable
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %cs0 = shape.const_shape [1, 3]
-  %cs1 = shape.const_shape [1, 5]
+  %cs0 = shape.const_shape [1, 3] : !shape.shape
+  %cs1 = shape.const_shape [1, 5] : !shape.shape
   %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
@@ -460,7 +460,7 @@ func @f(%arg0 : !shape.shape) {
   // CHECK-NEXT: shape.cstr_broadcastable
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %cs0 = shape.const_shape [1,3]
+  %cs0 = shape.const_shape [1, 3] : !shape.shape
   %0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
@@ -498,7 +498,7 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
 func @fold_rank() -> !shape.size {
   // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
   // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %shape = shape.const_shape [3, 4, 5, 6, 7]
+  %shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape
   %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
@@ -571,7 +571,7 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.const_shape []
+  %0 = shape.const_shape [] : !shape.shape
   %1 = shape.shape_of %arg0 : tensor<?xf32>
   %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   "consume.witness"(%2) : (!shape.witness) -> ()
@@ -617,9 +617,9 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
 func @shape_eq_fold_1() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant true
   // CHECK: return %[[RESULT]] : i1
-  %a = shape.const_shape [1, 2, 3]
-  %b = shape.const_shape [1, 2, 3]
-  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  %a = shape.const_shape [1, 2, 3] : !shape.shape
+  %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
   return %result : i1
 }
 
@@ -630,9 +630,9 @@ func @shape_eq_fold_1() -> i1 {
 func @shape_eq_fold_0() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant false
   // CHECK: return %[[RESULT]] : i1
-  %a = shape.const_shape [1, 2, 3]
-  %b = shape.const_shape [4, 5, 6]
-  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
 
@@ -643,8 +643,8 @@ func @shape_eq_fold_0() -> i1 {
 func @shape_eq_fold_0() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant false
   // CHECK: return %[[RESULT]] : i1
-  %a = shape.const_shape [1, 2, 3, 4, 5, 6]
-  %b = shape.const_shape [1, 2, 3]
+  %a = shape.const_shape [1, 2, 3, 4, 5, 6] : !shape.shape
+  %b = shape.const_shape [1, 2, 3] : !shape.shape
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1
 }
@@ -658,7 +658,7 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
   // CHECK: %[[B:.*]] = shape.const_shape [4, 5, 6]
   // CHECK: %[[RESULT:.*]] = shape.shape_eq %[[A]], %[[B]] : !shape.shape, !shape.shape
   // CHECK: return %[[RESULT]] : i1
-  %b = shape.const_shape [4, 5, 6]
+  %b = shape.const_shape [4, 5, 6] : !shape.shape
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1
 }

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 1187d7ad92bb..aace26de0ea2 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -33,48 +33,55 @@ func @test_shape_num_elements_unknown() {
   return
 }
 
+func @const_shape() {
+  %0 = shape.const_shape [1, 2, 3] : !shape.shape
+  %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  return
+}
+
 func @test_shape_num_elements_fixed() {
-  %0 = shape.const_shape [1, 57, 92]
+  %0 = shape.const_shape [1, 57, 92] : !shape.shape
   %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
   %3 = "shape.print"(%1) : (!shape.size) -> !shape.size
   return
 }
 
 func @test_broadcast_fixed() {
-  %0 = shape.const_shape [10, 1, 57, 92]
-  %1 = shape.const_shape [4, 57, 92]
+  %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
+  %1 = shape.const_shape [4, 57, 92] : !shape.shape
   %2 = shape.broadcast %0, %1
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_fixed() {
-  %0 = shape.const_shape [4, 57, 92]
-  %1 = shape.const_shape [4, 57, 92]
+  %0 = shape.const_shape [4, 57, 92] : !shape.shape
+  %1 = shape.const_shape [4, 57, 92] : !shape.shape
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_unknown() {
-  %0 = shape.const_shape [4, -1, 92]
-  %1 = shape.const_shape [-1, 57, 92]
+  %0 = shape.const_shape [4, -1, 92] : !shape.shape
+  %1 = shape.const_shape [-1, 57, 92] : !shape.shape
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_fixed_mismatch() {
-  %0 = shape.const_shape [4, 57, 92]
-  %1 = shape.const_shape [2, 57, 92]
+  %0 = shape.const_shape [4, 57, 92] : !shape.shape
+  %1 = shape.const_shape [2, 57, 92] : !shape.shape
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_parse_const_shape() {
-  %0 = shape.const_shape []
-  %1 = shape.const_shape [1, 2, 3]
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = shape.const_shape [1, 2, 3] : !shape.shape
+  %2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
   return
 }
 
@@ -84,8 +91,8 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
 }
 
 func @test_constraints() {
-  %0 = shape.const_shape []
-  %1 = shape.const_shape [1, 2, 3]
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = shape.const_shape [1, 2, 3] : !shape.shape
   %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   %w1 = shape.cstr_eq %0, %1
   %w2 = shape.const_witness true


        


More information about the Mlir-commits mailing list