[Mlir-commits] [mlir] 9a65d68 - [mlir] Add target for Shape dialect
Jacques Pienaar
llvmlistbot at llvm.org
Tue Mar 17 14:54:37 PDT 2020
Author: Jacques Pienaar
Date: 2020-03-17T14:54:25-07:00
New Revision: 9a65d683e02a4c09451deafbe546ddcb24e32d27
URL: https://github.com/llvm/llvm-project/commit/9a65d683e02a4c09451deafbe546ddcb24e32d27
DIFF: https://github.com/llvm/llvm-project/commit/9a65d683e02a4c09451deafbe546ddcb24e32d27.diff
LOG: [mlir] Add target for Shape dialect
Summary:
Add targets and basic printing/parsing of types in Shape dialect.
Differential Revision: https://reviews.llvm.org/D76321
Added:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/ops.mlir
Modified:
mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/CMakeLists.txt
mlir/lib/Dialect/Shape/CMakeLists.txt
Removed:
mlir/lib/Dialect/Shape/DialectRegistration.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
index 702ec621f486..6f4d28c339c8 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
@@ -1 +1,5 @@
-add_mlir_dialect(ShapeOps shape ShapeOps)
+# Run mlir-tblgen over ShapeOps.td to generate the op declarations, the op
+# definitions and the dialect declaration. The MLIRShape library lists the
+# MLIRShapeOpsIncGen target in its DEPENDS so these are generated first.
+set(LLVM_TARGET_DEFINITIONS ShapeOps.td)
+mlir_tablegen(ShapeOps.h.inc -gen-op-decls)
+mlir_tablegen(ShapeOps.cpp.inc -gen-op-defs)
+mlir_tablegen(ShapeOpsDialect.h.inc -gen-dialect-decls)
+add_public_tablegen_target(MLIRShapeOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index fe302e62e9c8..37ea17e2bfec 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -27,7 +27,8 @@ enum Kind {
Element,
Shape,
Size,
- ValueShape
+ ValueShape,
+ LAST_SHAPE_TYPE = ValueShape
};
} // namespace ShapeTypes
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 105677fd08e1..da2b964ce6aa 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
#define SHAPE_OPS
include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffects.td"
// TODO(jpienaar): Move to base.
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
@@ -40,20 +41,24 @@ def ShapeDialect : Dialect {
let cppNamespace = "shape";
}
-def Shape_SizeType : DialectType<ShapeDialect,
- CPred<"$_self.isa<DimType>()">, "dim"> {
+def Shape_ComponentType : DialectType<ShapeDialect,
+ CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type"> {
let typeDescription = [{
- `shape.size` represents a non-negative integer with support for being
- unknown and invalid.
+ `shape.component` (written `!shape.component` in IR) is the Shape dialect's
+ component type. TODO: describe what a component holds.
+ }];
+}
- Operations on `shape.size` types are specialized to handle unknown/dynamic
- value. So, for example, `<unknown> + x == <unknown>` for all non-error `x :
- !shape.size` (e.g., an unknown value does not become known due to addition).
+def Shape_ElementType : DialectType<ShapeDialect,
+ CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type"> {
+ let typeDescription = [{
+ `shape.element_type` represents the element type of the ShapedType. It may
+ be unknown, error or regular element type supported by ShapedType.
}];
}
def Shape_ShapeType : DialectType<ShapeDialect,
- CPred<"$_self.isa<ShapeType>()">, "shape"> {
+ CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape"> {
let typeDescription = [{
`shape.type` represents either an unranked shape, a ranked shape with
possibly unknown dimensions or an invalid shape. The rank is of type
@@ -70,24 +75,20 @@ def Shape_ShapeType : DialectType<ShapeDialect,
}];
}
-def Shape_ElementType : DialectType<ShapeDialect,
- CPred<"$_self.isa<ElementType>()">, "element type"> {
+def Shape_SizeType : DialectType<ShapeDialect,
+ CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size"> {
let typeDescription = [{
- `shape.element_type` represents the element type of the ShapedType. It may
- be unknown, error or regular element type supported by ShapedType.
- }];
-}
+ `shape.size` represents a non-negative integer with support for being
+ unknown and invalid.
-def Shape_ComponentType : DialectType<ShapeDialect,
- CPred<"$_self.isa<ComponentType>()">, "component type"> {
- let typeDescription = [{
- `shape.element_type` represents the element type of the ShapedType. It may
- be unknown, error or regular element type supported by ShapedType.
+ Operations on `shape.size` types are specialized to handle unknown/dynamic
+ value. So, for example, `<unknown> + x == <unknown>` for all non-error `x :
+ !shape.size` (e.g., an unknown value does not become known due to addition).
}];
}
def Shape_ValueShapeType : DialectType<ShapeDialect,
- CPred<"$_self.isa<ValueShapeType>()">, "value shape"> {
+ CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape"> {
let typeDescription = [{
`shape.value_shape` represents the value produced by an operation (this
corresponds to `Value` in the compiler) and a shape. Conceptually this is a
@@ -116,8 +117,8 @@ def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> {
* lhs + rhs = (int)lhs + (int)rhs if known;
}];
- let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
- let results = (outs Shape_ShapeType:$result);
+ let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
+ let results = (outs Shape_SizeType:$result);
}
def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
@@ -158,8 +159,13 @@ def Shape_ConstantOp : Shape_Op<"constant", []> {
// TODO(jpienaar): Change to a more specialized attribute that would
// encapsulate the unknown parsing while using denser packing.
- let arguments = (ins ArrayAttr:$value);
+ let arguments = (ins AnyAttr:$value);
let results = (outs Shape_ShapeOrSizeType:$result);
+
+ // TODO: Move this to main so that all shape ops implement these.
+ let printer = [{ return ::print(p, *this); }];
+ let verifier = [{ return ::verify(*this); }];
+ let parser = [{ return ::parse$cppClass(parser, result); }];
}
def Shape_CreateShapeOp : Shape_Op<"create_shape", []> {
@@ -214,8 +220,8 @@ def Shape_MulOp : Shape_Op<"mul", [SameOperandsAndResultType]> {
- lhs * rhs = (int)lhs * (int)rhs if both known;
}];
- let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
- let results = (outs Shape_ShapeType:$result);
+ let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
+ let results = (outs Shape_SizeType:$result);
}
def Shape_ReduceOp : Shape_Op<"reduce", []> {
@@ -244,7 +250,7 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
%acc = "shape.mul"(%lci, %dim) :
(!shape.size, !shape.size) -> !shape.size
- "shape.return"(%acc) : (!shape.size) -> ()
+ shape.yield %acc : !shape.size
}) : (!shape.type, !shape.size) -> (!shape.size)
return %1 : !shape.size
}
@@ -266,6 +272,18 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
let results = (outs Shape_ShapeType:$result);
}
+// Terminator for regions nested in Shape ops (e.g. the body of shape.reduce),
+// forwarding zero or more values of any type to the parent op.
+def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
+ let summary = "Returns the value to parent op";
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ // Convenience builder that creates a yield with no operands.
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+ >];
+
+ // Custom form: `shape.yield {attrs}` or `shape.yield %a, %b : t0, t1`; the
+ // operand list and its types are printed only when operands are present.
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
// TODO: Add Ops: if_static, if_ranked
// For testing usage.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 2d0099af23b8..f2615e8caccf 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -26,6 +26,7 @@
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Dialect.h"
@@ -50,6 +51,7 @@ inline void registerAllDialects() {
registerDialect<NVVM::NVVMDialect>();
registerDialect<ROCDL::ROCDLDialect>();
registerDialect<SDBMDialect>();
+ registerDialect<shape::ShapeDialect>();
return true;
}();
(void)init_once;
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 192616a1409f..2bb137f4795b 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -7,7 +7,7 @@ add_subdirectory(LoopOps)
add_subdirectory(OpenMP)
add_subdirectory(QuantOps)
add_subdirectory(SDBM)
-#add_subdirectory(Shape)
+add_subdirectory(Shape)
add_subdirectory(SPIRV)
add_subdirectory(StandardOps)
add_subdirectory(VectorOps)
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
index de3ddfce8771..485a3c710abb 100644
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -1,12 +1,15 @@
-file(GLOB globbed *.c *.cpp)
add_mlir_dialect_library(MLIRShape
- ${globbed}
+ IR/Shape.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape
+
+ DEPENDS
+ MLIRShapeOpsIncGen
)
-add_dependencies(MLIRShape MLIRShapeOpsIncGen LLVMSupport)
target_link_libraries(MLIRShape
PUBLIC
+ MLIRIR
MLIRSideEffects
- LLVMSupport)
+ LLVMSupport
+ )
diff --git a/mlir/lib/Dialect/Shape/DialectRegistration.cpp b/mlir/lib/Dialect/Shape/DialectRegistration.cpp
deleted file mode 100644
index 818ac149453c..000000000000
--- a/mlir/lib/Dialect/Shape/DialectRegistration.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-//===- DialectRegistration.cpp - Register shape dialect -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Shape/IR/Shape.h"
-using namespace mlir;
-
-// Static initialization for shape dialect registration.
-static DialectRegistration<shape::ShapeDialect> ShapeOps;
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
new file mode 100644
index 000000000000..f7f69a64826b
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -0,0 +1,116 @@
+//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/IR/Shape.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+/// Constructs the Shape dialect: registers the ODS-generated ops and the five
+/// Shape types, and permits operations in this dialect that are not (yet)
+/// registered.
+ShapeDialect::ShapeDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context) {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
+ >();
+ addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
+ // Allow unknown operations during prototyping and testing. As the dialect is
+ // still evolving, this makes it simple to start with unregistered ops and
+ // try different variants before actually defining the op.
+ allowUnknownOperations();
+}
+
+/// Parse a type registered to this dialect.
+///
+/// The type body is a single bare keyword -- `component`, `element`, `shape`,
+/// `size` or `value_shape` -- mapped to the corresponding Shape type. These
+/// are the same spellings ShapeDialect::printType emits, so types round-trip.
+/// Returns a null Type if no keyword could be parsed, or after emitting an
+/// "unknown shape type" error for any other keyword.
+Type ShapeDialect::parseType(DialectAsmParser &parser) const {
+ StringRef keyword;
+ if (parser.parseKeyword(&keyword))
+ return Type();
+
+ // Each keyword maps to one parameterless Shape type.
+ if (keyword == "component")
+ return ComponentType::get(getContext());
+ if (keyword == "element")
+ return ElementType::get(getContext());
+ if (keyword == "shape")
+ return ShapeType::get(getContext());
+ if (keyword == "size")
+ return SizeType::get(getContext());
+ if (keyword == "value_shape")
+ return ValueShapeType::get(getContext());
+
+ parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
+ return Type();
+}
+
+/// Print a type registered to this dialect.
+///
+/// Emits the bare keyword for the type's kind; the spellings must stay in
+/// sync with the keywords accepted by ShapeDialect::parseType. A kind outside
+/// the ShapeTypes enum is treated as unreachable.
+void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
+ switch (type.getKind()) {
+ case ShapeTypes::Component:
+ os << "component";
+ return;
+ case ShapeTypes::Element:
+ os << "element";
+ return;
+ case ShapeTypes::Size:
+ os << "size";
+ return;
+ case ShapeTypes::Shape:
+ os << "shape";
+ return;
+ case ShapeTypes::ValueShape:
+ os << "value_shape";
+ return;
+ default:
+ llvm_unreachable("unexpected 'shape' type kind");
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Constant*Op
+//===----------------------------------------------------------------------===//
+
+/// Prints shape.constant in its custom form:
+///   shape.constant [{attrs}] <value> : <result type>
+/// The `value` attribute is elided from the attribute dictionary and printed
+/// on its own, without its type.
+static void print(OpAsmPrinter &p, ConstantOp &op) {
+ p << "shape.constant ";
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
+
+ // More than one attribute means a dictionary was printed above (everything
+ // except `value`); separate it from the value that follows.
+ if (op.getAttrs().size() > 1)
+ p << ' ';
+ p.printAttributeWithoutType(op.value());
+ p << " : " << op.getType();
+}
+
+/// Parses the custom form produced by print(ConstantOp):
+///   shape.constant [{attrs}] <value> : <result type>
+/// The value is parsed with i64 as its expected type (e.g. for a bare integer
+/// literal such as `0`) and stored as the `value` attribute.
+static ParseResult parseConstantOp(OpAsmParser &parser,
+ OperationState &result) {
+ Attribute valueAttr;
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ Type i64Type = parser.getBuilder().getIntegerType(64);
+ if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes))
+ return failure();
+
+ Type type;
+ if (parser.parseColonType(type))
+ return failure();
+
+ // Record the parsed type as the op's single result type.
+ return parser.addTypeToList(type, result.types);
+}
+
+/// Custom verifier hook for shape.constant; no additional constraints are
+/// checked here yet, so it always succeeds.
+static LogicalResult verify(ConstantOp &op) { return success(); }
+
+// Pull in the ODS-generated op class definitions. They are emitted inside
+// mlir::shape to match the dialect's `cppNamespace`.
+namespace mlir {
+namespace shape {
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
+
+} // namespace shape
+} // namespace mlir
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
new file mode 100644
index 000000000000..0dc1c32894d2
--- /dev/null
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s --dump-input-on-failure
+
+// Parse/verify smoke tests for the Shape dialect. Only @shape_num_elements is
+// matched (CHECK-LABEL); the other functions are only required to parse and
+// verify. There are no `// -----` separators, so -split-input-file processes
+// the whole file as a single chunk.
+
+// CHECK-LABEL: shape_num_elements
+func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
+ %0 = shape.constant 0 : !shape.size
+ %1 = "shape.reduce"(%shape, %0) ( {
+ ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
+ %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
+ "shape.yield"(%acc) : (!shape.size) -> ()
+ }) : (!shape.shape, !shape.size) -> (!shape.size)
+ return %1 : !shape.size
+}
+
+func @test_shape_num_elements_unknown() {
+ %0 = "shape.unknown_shape"() : () -> !shape.shape
+ %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
+ %2 = "shape.print"(%1) : (!shape.size) -> !shape.size
+ return
+}
+
+func @test_shape_num_elements_fixed() {
+ %0 = "shape.constant"() { value = [1, 57, 92] }: () -> !shape.shape
+ %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
+ %3 = "shape.print"(%1) : (!shape.size) -> !shape.size
+ return
+}
+
+func @test_broadcastable_fixed() {
+ %0 = "shape.constant"() { value = [10, 1, 57, 92] }: () -> !shape.shape
+ %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+ %2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
+ return
+}
+
+func @test_shape_any_fixed() {
+ %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+ %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+ %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
+ return
+}
+
+func @test_shape_any_unknown() {
+ %0 = "shape.constant"() { value = [4, -1, 92] }: () -> !shape.shape
+ %1 = "shape.constant"() { value = [-1, 57, 92] }: () -> !shape.shape
+ %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
+ return
+}
+
+func @test_shape_any_fixed_mismatch() {
+ %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+ %1 = "shape.constant"() { value = [2, 57, 92] }: () -> !shape.shape
+ %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
+ return
+}
More information about the Mlir-commits
mailing list