[Mlir-commits] [mlir] 9a65d68 - [mlir] Add target for Shape dialect

Jacques Pienaar llvmlistbot at llvm.org
Tue Mar 17 14:54:37 PDT 2020


Author: Jacques Pienaar
Date: 2020-03-17T14:54:25-07:00
New Revision: 9a65d683e02a4c09451deafbe546ddcb24e32d27

URL: https://github.com/llvm/llvm-project/commit/9a65d683e02a4c09451deafbe546ddcb24e32d27
DIFF: https://github.com/llvm/llvm-project/commit/9a65d683e02a4c09451deafbe546ddcb24e32d27.diff

LOG: [mlir] Add target for Shape dialect

Summary:
Add targets and basic printing/parsing of types in Shape dialect.

Differential Revision: https://reviews.llvm.org/D76321

Added: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/ops.mlir

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Shape/IR/Shape.h
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/CMakeLists.txt
    mlir/lib/Dialect/Shape/CMakeLists.txt

Removed: 
    mlir/lib/Dialect/Shape/DialectRegistration.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
index 702ec621f486..6f4d28c339c8 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
@@ -1 +1,5 @@
-add_mlir_dialect(ShapeOps shape ShapeOps)
+set(LLVM_TARGET_DEFINITIONS ShapeOps.td)
+mlir_tablegen(ShapeOps.h.inc -gen-op-decls)
+mlir_tablegen(ShapeOps.cpp.inc -gen-op-defs)
+mlir_tablegen(ShapeOpsDialect.h.inc -gen-dialect-decls)
+add_public_tablegen_target(MLIRShapeOpsIncGen)

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index fe302e62e9c8..37ea17e2bfec 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -27,7 +27,8 @@ enum Kind {
   Element,
   Shape,
   Size,
-  ValueShape
+  ValueShape,
+  LAST_SHAPE_TYPE = ValueShape
 };
 } // namespace ShapeTypes
 

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 105677fd08e1..da2b964ce6aa 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
 #define SHAPE_OPS
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffects.td"
 
 // TODO(jpienaar): Move to base.
 def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
@@ -40,20 +41,24 @@ def ShapeDialect : Dialect {
   let cppNamespace = "shape";
 }
 
-def Shape_SizeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<DimType>()">, "dim"> {
+def Shape_ComponentType : DialectType<ShapeDialect,
+    CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type"> {
   let typeDescription = [{
-    `shape.size` represents a non-negative integer with support for being
-    unknown and invalid.
+    `shape.element_type` represents the element type of the ShapedType. It may
+    be unknown, error or regular element type supported by ShapedType.
+  }];
+}
 
-    Operations on `shape.size` types are specialized to handle unknown/dynamic
-    value. So, for example, `<unknown> + x == <unknown>` for all non-error `x :
-    !shape.size` (e.g., an unknown value does not become known due to addition).
+def Shape_ElementType : DialectType<ShapeDialect,
+    CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type"> {
+  let typeDescription = [{
+    `shape.element_type` represents the element type of the ShapedType. It may
+    be unknown, error or regular element type supported by ShapedType.
   }];
 }
 
 def Shape_ShapeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<ShapeType>()">, "shape"> {
+    CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape"> {
   let typeDescription = [{
     `shape.type` represents either an unranked shape, a ranked shape with
     possibly unknown dimensions or an invalid shape. The rank is of type
@@ -70,24 +75,20 @@ def Shape_ShapeType : DialectType<ShapeDialect,
   }];
 }
 
-def Shape_ElementType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<ElementType>()">, "element type"> {
+def Shape_SizeType : DialectType<ShapeDialect,
+    CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size"> {
   let typeDescription = [{
-    `shape.element_type` represents the element type of the ShapedType. It may
-    be unknown, error or regular element type supported by ShapedType.
-  }];
-}
+    `shape.size` represents a non-negative integer with support for being
+    unknown and invalid.
 
-def Shape_ComponentType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<ComponentType>()">, "component type"> {
-  let typeDescription = [{
-    `shape.element_type` represents the element type of the ShapedType. It may
-    be unknown, error or regular element type supported by ShapedType.
+    Operations on `shape.size` types are specialized to handle unknown/dynamic
+    value. So, for example, `<unknown> + x == <unknown>` for all non-error `x :
+    !shape.size` (e.g., an unknown value does not become known due to addition).
   }];
 }
 
 def Shape_ValueShapeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<ValueShapeType>()">, "value shape"> {
+    CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape"> {
   let typeDescription = [{
     `shape.value_shape` represents the value produced by an operation (this
     corresponds to `Value` in the compiler) and a shape. Conceptually this is a
@@ -116,8 +117,8 @@ def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> {
     * lhs + rhs = (int)lhs + (int)rhs if known;
   }];
 
-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
-  let results = (outs Shape_ShapeType:$result);
+  let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
+  let results = (outs Shape_SizeType:$result);
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
@@ -158,8 +159,13 @@ def Shape_ConstantOp : Shape_Op<"constant", []> {
 
   // TODO(jpienaar): Change to a more specialized attribute that would
   // encapsulate the unknown parsing while using denser packing.
-  let arguments = (ins ArrayAttr:$value);
+  let arguments = (ins AnyAttr:$value);
   let results = (outs Shape_ShapeOrSizeType:$result);
+
+  // TODO: Move this to main so that all shape ops implement these.
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
 def Shape_CreateShapeOp : Shape_Op<"create_shape", []> {
@@ -214,8 +220,8 @@ def Shape_MulOp : Shape_Op<"mul", [SameOperandsAndResultType]> {
     - lhs * rhs = (int)lhs * (int)rhs if both known;
   }];
 
-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
-  let results = (outs Shape_ShapeType:$result);
+  let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
+  let results = (outs Shape_SizeType:$result);
 }
 
 def Shape_ReduceOp : Shape_Op<"reduce", []> {
@@ -244,7 +250,7 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
         ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
           %acc = "shape.mul"(%lci, %dim) :
             (!shape.size, !shape.size) -> !shape.size
-          "shape.return"(%acc) : (!shape.size) -> ()
+          shape.yield %acc : !shape.size
         }) : (!shape.type, !shape.size) -> (!shape.size)
       return %1 : !shape.size
     }
@@ -266,6 +272,18 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
   let results = (outs Shape_ShapeType:$result);
 }
 
+def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
+  let summary = "Returns the value to parent op";
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [OpBuilder<
+    "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+  >];
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
 // TODO: Add Ops: if_static, if_ranked
 
 // For testing usage.

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 2d0099af23b8..f2615e8caccf 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/QuantOps/QuantOps.h"
 #include "mlir/Dialect/SDBM/SDBMDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/Dialect.h"
@@ -50,6 +51,7 @@ inline void registerAllDialects() {
     registerDialect<NVVM::NVVMDialect>();
     registerDialect<ROCDL::ROCDLDialect>();
     registerDialect<SDBMDialect>();
+    registerDialect<shape::ShapeDialect>();
     return true;
   }();
   (void)init_once;

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 192616a1409f..2bb137f4795b 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -7,7 +7,7 @@ add_subdirectory(LoopOps)
 add_subdirectory(OpenMP)
 add_subdirectory(QuantOps)
 add_subdirectory(SDBM)
-#add_subdirectory(Shape)
+add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
 add_subdirectory(VectorOps)

diff  --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
index de3ddfce8771..485a3c710abb 100644
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -1,12 +1,15 @@
-file(GLOB globbed *.c *.cpp)
 add_mlir_dialect_library(MLIRShape
-  ${globbed}
+  IR/Shape.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape
+
+  DEPENDS
+  MLIRShapeOpsIncGen
   )
-add_dependencies(MLIRShape MLIRShapeOpsIncGen LLVMSupport)
 target_link_libraries(MLIRShape
   PUBLIC
+  MLIRIR
   MLIRSideEffects
-  LLVMSupport)
+  LLVMSupport
+  )

diff  --git a/mlir/lib/Dialect/Shape/DialectRegistration.cpp b/mlir/lib/Dialect/Shape/DialectRegistration.cpp
deleted file mode 100644
index 818ac149453c..000000000000
--- a/mlir/lib/Dialect/Shape/DialectRegistration.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-//===- DialectRegistration.cpp - Register shape dialect -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Shape/IR/Shape.h"
-using namespace mlir;
-
-// Static initialization for shape dialect registration.
-static DialectRegistration<shape::ShapeDialect> ShapeOps;

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
new file mode 100644
index 000000000000..f7f69a64826b
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -0,0 +1,116 @@
+//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/IR/Shape.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+ShapeDialect::ShapeDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
+      >(); // Registers the ops generated from ShapeOps.td.
+  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
+  // Allow unknown operations during prototyping and testing. As the dialect is
+  // still evolving, this makes it simple to start with unregistered ops and
+  // try different variants before actually defining the op.
+  allowUnknownOperations();
+}
+
+/// Parse a type registered to this dialect from its keyword (e.g. `size`).
+Type ShapeDialect::parseType(DialectAsmParser &parser) const {
+  StringRef keyword;
+  if (parser.parseKeyword(&keyword))
+    return Type(); // A null Type signals failure to the caller.
+  // Each keyword maps to a type built from the context alone.
+  if (keyword == "component")
+    return ComponentType::get(getContext());
+  if (keyword == "element")
+    return ElementType::get(getContext());
+  if (keyword == "shape")
+    return ShapeType::get(getContext());
+  if (keyword == "size")
+    return SizeType::get(getContext());
+  if (keyword == "value_shape")
+    return ValueShapeType::get(getContext());
+  // Any other keyword is not a Shape type: report it and return null.
+  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
+  return Type();
+}
+
+/// Print a type registered to this dialect using parseType's keywords.
+void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
+  switch (type.getKind()) {
+  case ShapeTypes::Component:
+    os << "component";
+    return;
+  case ShapeTypes::Element:
+    os << "element";
+    return;
+  case ShapeTypes::Size:
+    os << "size";
+    return;
+  case ShapeTypes::Shape:
+    os << "shape";
+    return;
+  case ShapeTypes::ValueShape:
+    os << "value_shape";
+    return;
+  default: // Not one of the five types added in the constructor.
+    llvm_unreachable("unexpected 'shape' type kind");
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantOp (custom form: shape.constant [attr-dict] value : type)
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, ConstantOp &op) {
+  p << "shape.constant ";
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
+
+  if (op.getAttrs().size() > 1) // Attrs besides `value` were printed above.
+    p << ' ';
+  p.printAttributeWithoutType(op.value());
+  p << " : " << op.getType();
+}
+/// Parses [attr-dict] value : type, passing i64 as the type for `value`.
+static ParseResult parseConstantOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  Attribute valueAttr;
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  Type i64Type = parser.getBuilder().getIntegerType(64);
+  if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes))
+    return failure();
+
+  Type type;
+  if (parser.parseColonType(type))
+    return failure();
+
+  // The type after ':' is the op's result type; add it to the result types.
+  return parser.addTypeToList(type, result.types);
+}
+/// Accepts every ConstantOp; no op-specific checks are performed.
+static LogicalResult verify(ConstantOp &op) { return success(); }
+
+namespace mlir {
+namespace shape {
+// Op class definitions generated from ShapeOps.td (-gen-op-defs).
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
+
+} // namespace shape
+} // namespace mlir

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
new file mode 100644
index 000000000000..0dc1c32894d2
--- /dev/null
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s --dump-input-on-failure
+
+// CHECK-LABEL: shape_num_elements
+func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
+  %0 = shape.constant 0 : !shape.size
+  %1 = "shape.reduce"(%shape, %0) ( {
+    ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
+      %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
+      "shape.yield"(%acc) : (!shape.size) -> ()
+    }) : (!shape.shape, !shape.size) -> (!shape.size)
+  return %1 : !shape.size
+}
+
+func @test_shape_num_elements_unknown() {
+  %0 = "shape.unknown_shape"() : () -> !shape.shape
+  %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
+  %2 = "shape.print"(%1) : (!shape.size) -> !shape.size
+  return
+}
+
+func @test_shape_num_elements_fixed() {
+  %0 = "shape.constant"() { value = [1, 57, 92] }: () -> !shape.shape
+  %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
+  %3 = "shape.print"(%1) : (!shape.size) -> !shape.size
+  return
+}
+
+func @test_broadcastable_fixed() {
+  %0 = "shape.constant"() { value = [10, 1, 57, 92] }: () -> !shape.shape
+  %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+  %2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
+  return
+}
+
+func @test_shape_any_fixed() {
+  %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+  %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
+  return
+}
+
+func @test_shape_any_unknown() {
+  %0 = "shape.constant"() { value = [4, -1, 92] }: () -> !shape.shape
+  %1 = "shape.constant"() { value = [-1, 57, 92] }: () -> !shape.shape
+  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
+  return
+}
+
+func @test_shape_any_fixed_mismatch() {
+  %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+  %1 = "shape.constant"() { value = [2, 57, 92] }: () -> !shape.shape
+  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
+  return
+}


        


More information about the Mlir-commits mailing list