[Mlir-commits] [mlir] 14d3cef - [MLIR][Shape] Generalize `shape.const_shape` to extent tensors
Frederik Gossen
llvmlistbot at llvm.org
Fri Jul 24 01:07:54 PDT 2020
Author: Frederik Gossen
Date: 2020-07-24T08:06:24Z
New Revision: 14d3cef01264c5575ada3ed7619e3e5b582ecbe4
URL: https://github.com/llvm/llvm-project/commit/14d3cef01264c5575ada3ed7619e3e5b582ecbe4
DIFF: https://github.com/llvm/llvm-project/commit/14d3cef01264c5575ada3ed7619e3e5b582ecbe4.diff
LOG: [MLIR][Shape] Generalize `shape.const_shape` to extent tensors
Previously, the operation `shape.const_shape` could only create constants of type `!shape.shape`.
It can now also create constant extent tensors (`tensor<?xindex>`).
Differential Revision: https://reviews.llvm.org/D84157
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 703353c35f5a..1bdfd9a071e3 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -96,18 +96,20 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
}
def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
- let summary = "Creates a constant of !shape.shape type";
+ let summary = "Creates a constant shape or extent tensor";
let description = [{
- Creates a !shape.shape with rank given by the length of `shape` and with
- dimension sizes given by the values of `shape`.
+ Creates a constant shape or extent tensor. The individual extents are given
+ as the `shape` attribute. The number of these values equals the shape's
+ rank.
```mlir
- %0 = shape.const_shape []
- %1 = shape.const_shape [1, 2, 3]
+ %0 = shape.const_shape [] : !shape.shape
+ %1 = shape.const_shape [1, 2, 3] : !shape.shape
+ %2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
```
}];
let arguments = (ins IndexElementsAttr:$shape);
- let results = (outs Shape_ShapeType:$result);
+ let results = (outs Shape_ShapeOrExtentTensorType:$result);
// TODO: Move this to main so that all shape ops implement these.
let printer = [{ return ::print(p, *this); }];
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 92392b069a04..42b8b34c7e09 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -23,6 +23,11 @@ namespace {
#include "ShapeCanonicalization.inc"
}
+static RankedTensorType getExtentTensorType(OpBuilder &builder) {
+ return RankedTensorType::get({ShapedType::kDynamicSize},
+ builder.getIndexType());
+}
+
ShapeDialect::ShapeDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
@@ -40,12 +45,12 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- if (auto shapeType = type.dyn_cast<ShapeType>())
+ if (type.isa<ShapeType>() || type == getExtentTensorType(builder))
return builder.create<ConstShapeOp>(loc, type,
value.cast<DenseIntElementsAttr>());
- if (auto sizeType = type.dyn_cast<SizeType>())
+ if (type.isa<SizeType>())
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
- if (auto witnessType = type.dyn_cast<WitnessType>())
+ if (type.isa<WitnessType>())
return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
return nullptr;
}
@@ -290,7 +295,8 @@ static void print(OpAsmPrinter &p, ConstShapeOp &op) {
p << "[";
interleaveComma(op.shape().getValues<int64_t>(), p,
[&](int64_t i) { p << i; });
- p << "]";
+ p << "] : ";
+ p.printType(op.getType());
}
static ParseResult parseConstShapeOp(OpAsmParser &parser,
@@ -316,8 +322,10 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
}
Builder &builder = parser.getBuilder();
result.addAttribute("shape", builder.getIndexTensorAttr(ints));
-
- result.types.push_back(ShapeType::get(builder.getContext()));
+ Type resultTy;
+ if (parser.parseColonType(resultTy))
+ return failure();
+ result.types.push_back(resultTy);
return success();
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 80b7cb9ddb94..20f21bbc877e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: func @f
func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
- // CHECK: shape.const_shape [2, 3, 4]
+ // CHECK: shape.const_shape [2, 3, 4] : !shape.shape
%0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
return %0 : !shape.shape
}
@@ -12,10 +12,10 @@ func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
// Basic case.
// CHECK-LABEL: func @f
func @f() -> (!shape.shape, !shape.shape) {
- // CHECK: shape.const_shape [2, 3]
- // CHECK: shape.const_shape [4, 5]
+ // CHECK: shape.const_shape [2, 3] : !shape.shape
+ // CHECK: shape.const_shape [4, 5] : !shape.shape
%c2 = constant 2 : i32
- %0 = shape.const_shape [2, 3, 4, 5]
+ %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
%head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
return %head, %tail : !shape.shape, !shape.shape
@@ -26,10 +26,10 @@ func @f() -> (!shape.shape, !shape.shape) {
// Negative split point.
// CHECK-LABEL: func @f
func @f() -> (!shape.shape, !shape.shape) {
- // CHECK: shape.const_shape [2, 3, 4]
- // CHECK: shape.const_shape [5]
+ // CHECK: shape.const_shape [2, 3, 4] : !shape.shape
+ // CHECK: shape.const_shape [5] : !shape.shape
%c-1 = constant -1 : i32
- %0 = shape.const_shape [2, 3, 4, 5]
+ %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
%head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
return %head, %tail : !shape.shape, !shape.shape
}
@@ -41,7 +41,7 @@ func @f() -> (!shape.shape, !shape.shape) {
func @f() -> (!shape.shape, !shape.shape) {
// CHECK: shape.split_at
%c5 = constant 5 : i32
- %0 = shape.const_shape [2, 3, 4, 5]
+ %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
%head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
return %head, %tail : !shape.shape, !shape.shape
}
@@ -51,9 +51,9 @@ func @f() -> (!shape.shape, !shape.shape) {
// Basic case.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
- // CHECK: shape.const_shape [7, 2]
- %0 = shape.const_shape [1, 2]
- %1 = shape.const_shape [7, 1]
+ // CHECK: shape.const_shape [7, 2] : !shape.shape
+ %0 = shape.const_shape [1, 2] : !shape.shape
+ %1 = shape.const_shape [7, 1] : !shape.shape
%2 = shape.broadcast %0, %1
return %2 : !shape.shape
}
@@ -64,7 +64,7 @@ func @f() -> !shape.shape {
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK: return %arg0
- %0 = shape.const_shape []
+ %0 = shape.const_shape [] : !shape.shape
%1 = shape.broadcast %arg0, %0
return %1 : !shape.shape
}
@@ -75,7 +75,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK: return %arg0
- %0 = shape.const_shape []
+ %0 = shape.const_shape [] : !shape.shape
%1 = shape.broadcast %0, %arg0
return %1 : !shape.shape
}
@@ -85,10 +85,10 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
// Lhs is a scalar and rhs is constant.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
- // CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3]
+ // CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3] : !shape.shape
// CHECK: return %[[CST]]
- %0 = shape.const_shape []
- %1 = shape.const_shape [1, 2, 3]
+ %0 = shape.const_shape [] : !shape.shape
+ %1 = shape.const_shape [1, 2, 3] : !shape.shape
%2 = shape.broadcast %0, %1
return %2 : !shape.shape
}
@@ -99,8 +99,8 @@ func @f() -> !shape.shape {
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
// CHECK: shape.broadcast
- %0 = shape.const_shape [2]
- %1 = shape.const_shape [7]
+ %0 = shape.const_shape [2] : !shape.shape
+ %1 = shape.const_shape [7] : !shape.shape
%2 = shape.broadcast %0, %1
return %2 : !shape.shape
}
@@ -110,9 +110,9 @@ func @f() -> !shape.shape {
// Basic case.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
- // CHECK: shape.const_shape [0, 1, 2, 3]
- %lhs = shape.const_shape [0, 1]
- %rhs = shape.const_shape [2, 3]
+ // CHECK: shape.const_shape [0, 1, 2, 3] : !shape.shape
+ %lhs = shape.const_shape [0, 1] : !shape.shape
+ %rhs = shape.const_shape [2, 3] : !shape.shape
%0 = shape.concat %lhs, %rhs
return %0 : !shape.shape
}
@@ -123,7 +123,7 @@ func @f() -> !shape.shape {
// CHECK-LABEL: func @f
func @f() -> tensor<2xindex> {
// CHECK: constant dense<[0, 1]> : tensor<2xindex>
- %cs = shape.const_shape [0, 1]
+ %cs = shape.const_shape [0, 1] : !shape.shape
%0 = shape.to_extent_tensor %cs : tensor<2xindex>
return %0 : tensor<2xindex>
}
@@ -133,7 +133,7 @@ func @f() -> tensor<2xindex> {
// Basic case.
// CHECK-LABEL: func @f()
func @f() -> !shape.shape {
- // CHECK: shape.const_shape [3, 5, 11]
+ // CHECK: shape.const_shape [3, 5, 11] : !shape.shape
%e0 = constant 3 : index
%e1 = constant 5 : index
%e2 = constant 11 : index
@@ -215,7 +215,7 @@ func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
// CHECK-LABEL: func @num_elements
func @num_elements() -> !shape.size {
// CHECK-NOT: shape.const_shape
- %shape = shape.const_shape [4, 5, 6]
+ %shape = shape.const_shape [4, 5, 6] : !shape.shape
// CHECK-NOT: shape.num_elements
%num_elements = shape.num_elements %shape
// CHECK: %[[NUM:.*]] = shape.const_size 120
@@ -239,7 +239,7 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
// CHECK-LABEL: func @basic
func @basic() -> !shape.size {
// CHECK: shape.const_size 2
- %0 = shape.const_shape [0, 1, 2]
+ %0 = shape.const_shape [0, 1, 2] : !shape.shape
%c2 = shape.const_size 2
%1 = shape.get_extent %0, %c2
return %1 : !shape.size
@@ -252,7 +252,7 @@ func @basic() -> !shape.size {
func @out_of_bounds() -> !shape.size {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
- %0 = shape.const_shape [0, 1, 2]
+ %0 = shape.const_shape [0, 1, 2] : !shape.shape
%c3 = shape.const_size 3
%1 = shape.get_extent %0, %c3
return %1 : !shape.size
@@ -289,9 +289,9 @@ func @f() {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %cs0 = shape.const_shape [0, 1]
- %cs1 = shape.const_shape [0, 1]
- %cs2 = shape.const_shape [0, 1]
+ %cs0 = shape.const_shape [0, 1] : !shape.shape
+ %cs1 = shape.const_shape [0, 1] : !shape.shape
+ %cs2 = shape.const_shape [0, 1] : !shape.shape
%0 = shape.cstr_eq %cs0, %cs1, %cs2
"consume.witness"(%0) : (!shape.witness) -> ()
return
@@ -306,8 +306,8 @@ func @f() {
// CHECK-NEXT: shape.cstr_eq
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %cs0 = shape.const_shape [0, 1]
- %cs1 = shape.const_shape [3, 1]
+ %cs0 = shape.const_shape [0, 1] : !shape.shape
+ %cs1 = shape.const_shape [3, 1] : !shape.shape
%0 = shape.cstr_eq %cs0, %cs1
"consume.witness"(%0) : (!shape.witness) -> ()
return
@@ -367,7 +367,7 @@ func @f() {
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape
// CHECK-NEXT: return %[[CS]]
- %0 = shape.const_shape [2, 3, 4]
+ %0 = shape.const_shape [2, 3, 4] : !shape.shape
%1 = shape.any %0, %arg0
return %1 : !shape.shape
}
@@ -429,8 +429,8 @@ func @f() {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %cs0 = shape.const_shape [3, 1]
- %cs1 = shape.const_shape [1, 5]
+ %cs0 = shape.const_shape [3, 1] : !shape.shape
+ %cs1 = shape.const_shape [1, 5] : !shape.shape
%0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
@@ -445,8 +445,8 @@ func @static_non_broadcastable() {
// CHECK-NEXT: shape.cstr_broadcastable
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %cs0 = shape.const_shape [1, 3]
- %cs1 = shape.const_shape [1, 5]
+ %cs0 = shape.const_shape [1, 3] : !shape.shape
+ %cs1 = shape.const_shape [1, 5] : !shape.shape
%0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
@@ -460,7 +460,7 @@ func @f(%arg0 : !shape.shape) {
// CHECK-NEXT: shape.cstr_broadcastable
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %cs0 = shape.const_shape [1,3]
+ %cs0 = shape.const_shape [1, 3] : !shape.shape
%0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
@@ -498,7 +498,7 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
func @fold_rank() -> !shape.size {
// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
// CHECK-DAG: return %[[RESULT]] : !shape.size
- %shape = shape.const_shape [3, 4, 5, 6, 7]
+ %shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape
%rank = shape.rank %shape : !shape.shape
return %rank : !shape.size
}
@@ -571,7 +571,7 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %0 = shape.const_shape []
+ %0 = shape.const_shape [] : !shape.shape
%1 = shape.shape_of %arg0 : tensor<?xf32>
%2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
"consume.witness"(%2) : (!shape.witness) -> ()
@@ -617,9 +617,9 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
func @shape_eq_fold_1() -> i1 {
// CHECK: %[[RESULT:.*]] = constant true
// CHECK: return %[[RESULT]] : i1
- %a = shape.const_shape [1, 2, 3]
- %b = shape.const_shape [1, 2, 3]
- %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+ %a = shape.const_shape [1, 2, 3] : !shape.shape
+ %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
+ %result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
return %result : i1
}
@@ -630,9 +630,9 @@ func @shape_eq_fold_1() -> i1 {
func @shape_eq_fold_0() -> i1 {
// CHECK: %[[RESULT:.*]] = constant false
// CHECK: return %[[RESULT]] : i1
- %a = shape.const_shape [1, 2, 3]
- %b = shape.const_shape [4, 5, 6]
- %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+ %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
+ %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
+ %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
return %result : i1
}
@@ -643,8 +643,8 @@ func @shape_eq_fold_0() -> i1 {
func @shape_eq_fold_0() -> i1 {
// CHECK: %[[RESULT:.*]] = constant false
// CHECK: return %[[RESULT]] : i1
- %a = shape.const_shape [1, 2, 3, 4, 5, 6]
- %b = shape.const_shape [1, 2, 3]
+ %a = shape.const_shape [1, 2, 3, 4, 5, 6] : !shape.shape
+ %b = shape.const_shape [1, 2, 3] : !shape.shape
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
}
@@ -658,7 +658,7 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
// CHECK: %[[B:.*]] = shape.const_shape [4, 5, 6]
// CHECK: %[[RESULT:.*]] = shape.shape_eq %[[A]], %[[B]] : !shape.shape, !shape.shape
// CHECK: return %[[RESULT]] : i1
- %b = shape.const_shape [4, 5, 6]
+ %b = shape.const_shape [4, 5, 6] : !shape.shape
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 1187d7ad92bb..aace26de0ea2 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -33,48 +33,55 @@ func @test_shape_num_elements_unknown() {
return
}
+func @const_shape() {
+ %0 = shape.const_shape [1, 2, 3] : !shape.shape
+ %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
+ return
+}
+
func @test_shape_num_elements_fixed() {
- %0 = shape.const_shape [1, 57, 92]
+ %0 = shape.const_shape [1, 57, 92] : !shape.shape
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
%3 = "shape.print"(%1) : (!shape.size) -> !shape.size
return
}
func @test_broadcast_fixed() {
- %0 = shape.const_shape [10, 1, 57, 92]
- %1 = shape.const_shape [4, 57, 92]
+ %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
+ %1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = shape.broadcast %0, %1
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_fixed() {
- %0 = shape.const_shape [4, 57, 92]
- %1 = shape.const_shape [4, 57, 92]
+ %0 = shape.const_shape [4, 57, 92] : !shape.shape
+ %1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_unknown() {
- %0 = shape.const_shape [4, -1, 92]
- %1 = shape.const_shape [-1, 57, 92]
+ %0 = shape.const_shape [4, -1, 92] : !shape.shape
+ %1 = shape.const_shape [-1, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_fixed_mismatch() {
- %0 = shape.const_shape [4, 57, 92]
- %1 = shape.const_shape [2, 57, 92]
+ %0 = shape.const_shape [4, 57, 92] : !shape.shape
+ %1 = shape.const_shape [2, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_parse_const_shape() {
- %0 = shape.const_shape []
- %1 = shape.const_shape [1, 2, 3]
+ %0 = shape.const_shape [] : !shape.shape
+ %1 = shape.const_shape [1, 2, 3] : !shape.shape
+ %2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
return
}
@@ -84,8 +91,8 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
}
func @test_constraints() {
- %0 = shape.const_shape []
- %1 = shape.const_shape [1, 2, 3]
+ %0 = shape.const_shape [] : !shape.shape
+ %1 = shape.const_shape [1, 2, 3] : !shape.shape
%w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
%w1 = shape.cstr_eq %0, %1
%w2 = shape.const_witness true
More information about the Mlir-commits
mailing list