[Mlir-commits] [mlir] d1ad267 - [shape] Basic constant folding.

Sean Silva llvmlistbot at llvm.org
Fri Apr 24 15:49:43 PDT 2020


Author: Sean Silva
Date: 2020-04-24T15:49:35-07:00
New Revision: d1ad267a56d33c4a6556c2947c8138123cefc08d

URL: https://github.com/llvm/llvm-project/commit/d1ad267a56d33c4a6556c2947c8138123cefc08d
DIFF: https://github.com/llvm/llvm-project/commit/d1ad267a56d33c4a6556c2947c8138123cefc08d.diff

LOG: [shape] Basic constant folding.

- Implement a first constant fold for shape.shape_of (more ops coming in subsequent patches)
- Implement the right builder interfaces for ShapeType and other types
- Split shape.constant into shape.const_size and shape.const_shape, which play better with dyn_cast and builders than one polymorphic op.

Also, fix the RUN line in ops.mlir to properly verify round-tripping.
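
For illustration, a minimal before/after sketch of the new fold (based on the
added canonicalize.mlir test; the exact canonicalized output may differ):

    // Input: shape.shape_of of a statically shaped tensor.
    %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
    // After -canonicalize, the fold plus the dialect's constant materializer
    // rewrite it to the new constant op:
    %0 = shape.const_shape [2, 3, 4]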

Added: 
    mlir/test/Dialect/Shape/canonicalize.mlir

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 795fb353768b..57f60068d7cd 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -40,10 +40,13 @@ def ShapeDialect : Dialect {
   }];
 
   let cppNamespace = "shape";
+
+  let hasConstantMaterializer = 1;
 }
 
 def Shape_ComponentType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type"> {
+    CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type">,
+    BuildableType<"$_builder.getType<::mlir::shape::ComponentType>()"> {
   let typeDescription = [{
     `shape.element_type` represents the element type of the ShapedType. It may
     be unknown, error or regular element type supported by ShapedType.
@@ -51,7 +54,8 @@ def Shape_ComponentType : DialectType<ShapeDialect,
 }
 
 def Shape_ElementType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type"> {
+    CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type">,
+    BuildableType<"$_builder.getType<::mlir::shape::ElementType>()"> {
   let typeDescription = [{
     `shape.element_type` represents the element type of the ShapedType. It may
     be unknown, error or regular element type supported by ShapedType.
@@ -59,7 +63,8 @@ def Shape_ElementType : DialectType<ShapeDialect,
 }
 
 def Shape_ShapeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape"> {
+    CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape">,
+    BuildableType<"$_builder.getType<::mlir::shape::ShapeType>()"> {
   let typeDescription = [{
     `shape.type` represents either an unranked shape, a ranked shape with
     possibly unknown dimensions or an invalid shape. The rank is of type
@@ -77,7 +82,8 @@ def Shape_ShapeType : DialectType<ShapeDialect,
 }
 
 def Shape_SizeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size"> {
+    CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size">,
+    BuildableType<"$_builder.getType<::mlir::shape::SizeType>()"> {
   let typeDescription = [{
     `shape.size` represents a non-negative integer with support for being
     unknown and invalid.
@@ -89,7 +95,9 @@ def Shape_SizeType : DialectType<ShapeDialect,
 }
 
 def Shape_ValueShapeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape"> {
+    CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape">,
+    BuildableType<"::mlir::shape::ValueShapeType::get($_builder.getContext())">
+{
   let typeDescription = [{
     `shape.value_shape` represents the value produced by an operation (this
     corresponds to `Value` in the compiler) and a shape. Conceptually this is a
@@ -146,27 +154,46 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
   let results = (outs Shape_ShapeType:$result);
 }
 
-def Shape_ConstantOp : Shape_Op<"constant", []> {
-  let summary = "Creates a shape constant";
+def Shape_ConstShapeOp : Shape_Op<"const_shape",
+    [ConstantLike,
+     NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Creates a constant of !shape.shape type.";
   let description = [{
-    An operation that builds a size or shape from integer or array attribute.
-    It allows for creating dynamically valued shapes by using `?` for unknown
-    values. A constant shape specified with `*` will return an unranked shape.
+    Creates a !shape.shape with rank given by the length of `shape` and with
+    dimension sizes given by the values of `shape`.
 
     ```mlir
-    %x = shape.constant 10 : !shape.size
+    %0 = shape.const_shape []
+    %1 = shape.const_shape [1, 2, 3]
     ```
   }];
-
-  // TODO(jpienaar): Change to a more specialized attribute that would
-  // encapsulate the unknown parsing while using denser packing.
-  let arguments = (ins AnyAttr:$value);
-  let results = (outs Shape_ShapeOrSizeType:$result);
+  let arguments = (ins I64ElementsAttr:$shape);
+  let results = (outs Shape_ShapeType:$result);
 
   // TODO: Move this to main so that all shape ops implement these.
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
+  let hasFolder = 1;
+}
+
+def Shape_ConstSizeOp : Shape_Op<"const_size",
+    [ConstantLike,
+     NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Creates a constant of !shape.size type.";
+  let description = [{
+    Creates a !shape.size type representing the constant size given by `value`.
+
+    ```mlir
+    %x = shape.const_size 10
+    ```
+  }];
+
+  let arguments = (ins IndexAttr:$value);
+  let results = (outs Shape_SizeType:$result);
+
+  let assemblyFormat = "attr-dict $value";
 }
 
 def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
@@ -291,6 +318,8 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
 
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
   let results = (outs Shape_ShapeType:$result);
+
+  let hasFolder = 1;
 }
 
 def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 85798ce9eff4..3d3c1a9b6454 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -29,6 +30,19 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
   allowUnknownOperations();
 }
 
+Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
+                                             Attribute value, Type type,
+                                             Location loc) {
+  if (auto shapeType = type.dyn_cast<ShapeType>()) {
+    return builder.create<ConstShapeOp>(loc, type,
+                                        value.cast<DenseIntElementsAttr>());
+  }
+  if (auto sizeType = type.dyn_cast<SizeType>()) {
+    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
+  }
+  return nullptr;
+}
+
 /// Parse a type registered to this dialect.
 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
   StringRef keyword;
@@ -74,37 +88,79 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
 }
 
 //===----------------------------------------------------------------------===//
-// Constant*Op
+// ConstShapeOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, ConstantOp &op) {
-  p << "shape.constant ";
-  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
-
-  if (op.getAttrs().size() > 1)
-    p << ' ';
-  p.printAttributeWithoutType(op.value());
-  p << " : " << op.getType();
+static void print(OpAsmPrinter &p, ConstShapeOp &op) {
+  p << "shape.const_shape ";
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
+  p << "[";
+  interleaveComma(op.shape().getValues<int64_t>(), p,
+                  [&](int64_t i) { p << i; });
+  p << "]";
 }
 
-static ParseResult parseConstantOp(OpAsmParser &parser,
-                                   OperationState &result) {
-  Attribute valueAttr;
+static ParseResult parseConstShapeOp(OpAsmParser &parser,
+                                     OperationState &result) {
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
-  Type i64Type = parser.getBuilder().getIntegerType(64);
-  if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes))
+  // We piggy-back on ArrayAttr parsing, though we don't internally store the
+  // shape as an ArrayAttr.
+  // TODO: Implement custom parser and maybe make syntax a bit more concise.
+  Attribute extentsRaw;
+  SmallVector<NamedAttribute, 6> dummy;
+  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
     return failure();
-
-  Type type;
-  if (parser.parseColonType(type))
+  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
+  if (!extentsArray)
     return failure();
+  SmallVector<int64_t, 6> ints;
+  for (Attribute extent : extentsArray) {
+    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
+    if (!attr)
+      return failure();
+    ints.push_back(attr.getInt());
+  }
+  Builder &builder = parser.getBuilder();
+  result.addAttribute("shape", builder.getI64TensorAttr(ints));
+
+  result.types.push_back(ShapeType::get(builder.getContext()));
+  return success();
+}
+
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+
+LogicalResult ConstShapeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
 
-  // Add the attribute type to the list.
-  return parser.addTypeToList(type, result.types);
+//===----------------------------------------------------------------------===//
+// ConstSizeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConstSizeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
 }
 
-static LogicalResult verify(ConstantOp &op) { return success(); }
+//===----------------------------------------------------------------------===//
+// ShapeOfOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
+  auto type = getOperand().getType().dyn_cast<ShapedType>();
+  if (!type || !type.hasStaticShape())
+    return nullptr;
+  Builder builder(getContext());
+  return builder.getI64TensorAttr(type.getShape());
+}
 
 //===----------------------------------------------------------------------===//
 // SplitAtOp

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
new file mode 100644
index 000000000000..fad31d840523
--- /dev/null
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -canonicalize <%s | FileCheck %s --dump-input=fail
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
+  // CHECK: shape.const_shape [2, 3, 4]
+  %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
+  return %0 : !shape.shape
+}

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 0dc1c32894d2..5ca3b0f49120 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s --dump-input-on-failure
+// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: shape_num_elements
 func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
-  %0 = shape.constant 0 : !shape.size
+  %0 = shape.const_size 0
   %1 = "shape.reduce"(%shape, %0) ( {
     ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
       %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
@@ -19,40 +19,46 @@ func @test_shape_num_elements_unknown() {
 }
 
 func @test_shape_num_elements_fixed() {
-  %0 = "shape.constant"() { value = [1, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [1, 57, 92]
   %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
   %3 = "shape.print"(%1) : (!shape.size) -> !shape.size
   return
 }
 
 func @test_broadcastable_fixed() {
-  %0 = "shape.constant"() { value = [10, 1, 57, 92] }: () -> !shape.shape
-  %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [10, 1, 57, 92]
+  %1 = shape.const_shape [4, 57, 92]
   %2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_fixed() {
-  %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
-  %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [4, 57, 92]
+  %1 = shape.const_shape [4, 57, 92]
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_unknown() {
-  %0 = "shape.constant"() { value = [4, -1, 92] }: () -> !shape.shape
-  %1 = "shape.constant"() { value = [-1, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [4, -1, 92]
+  %1 = shape.const_shape [-1, 57, 92]
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_fixed_mismatch() {
-  %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
-  %1 = "shape.constant"() { value = [2, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [4, 57, 92]
+  %1 = shape.const_shape [2, 57, 92]
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
+
+func @test_parse_const_shape() {
+  %0 = shape.const_shape []
+  %1 = shape.const_shape [1, 2, 3]
+  return
+}
