[Mlir-commits] [mlir] 569e4f9 - `shape` dialect: add some ops

Sean Silva llvmlistbot at llvm.org
Fri Mar 27 16:44:34 PDT 2020


Author: Sean Silva
Date: 2020-03-27T16:38:42-07:00
New Revision: 569e4f9bc99a755cc30f0102b29b1eefd4fa33b4

URL: https://github.com/llvm/llvm-project/commit/569e4f9bc99a755cc30f0102b29b1eefd4fa33b4
DIFF: https://github.com/llvm/llvm-project/commit/569e4f9bc99a755cc30f0102b29b1eefd4fa33b4.diff

LOG: `shape` dialect: add some ops

- add `to_extent_tensor`
- rename `create_shape` to `from_extent_tensor` for symmetry
- add `split_at` and `concat` ops for basic shape manipulations

This set of ops is inspired by the requirements of lowering a dynamic-shape-aware batch matmul op. For such an op, the "matrix" dimensions aren't subject to broadcasting but the others are, and so we need to slice, broadcast, and reconstruct the final output shape. Furthermore, the actual broadcasting op used downstream uses a tensor of extents as its preferred shape interface for the actual op that does the broadcasting.

However, this functionality is quite general. It's obvious that `to_extent_tensor` is needed long-term to support many common patterns that involve computations on shapes. We can evolve the shape manipulation ops introduced here. The specific choices made here took into consideration the potentially unranked nature of the !shape.shape type, which means that a simple listing of dimensions to extract isn't possible in general.

Differential Revision: https://reviews.llvm.org/D76817

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/Shape.h
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/CMakeLists.txt
    mlir/lib/Dialect/Shape/IR/Shape.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 37ea17e2bfec..0134ba9381ac 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -16,6 +16,8 @@
 
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffects.h"
 
 namespace mlir {

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index da2b964ce6aa..3694212cb1cc 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,6 +14,7 @@
 #define SHAPE_OPS
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffects.td"
 
 // TODO(jpienaar): Move to base.
@@ -168,17 +169,37 @@ def Shape_ConstantOp : Shape_Op<"constant", []> {
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def Shape_CreateShapeOp : Shape_Op<"create_shape", []> {
-  let summary = "Creates a shape descriptor from a tensor";
+def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
+  let summary = "Creates a shape from a tensor of extents";
   let description = [{
-    Creates a shape from a 1D integral tensor. The rank equals the number of
-    elements in the tensor, and extent matches the values of the elements.
+    Creates a shape from a 1D integral tensor of extents. The rank of the
+    resulting shape equals the number of elements in the tensor, and the
+    extents match the values of the elements.
   }];
 
   let arguments = (ins I32Tensor:$input);
   let results = (outs Shape_ShapeType:$result);
 }
 
+def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", []> {
+  let summary = "Creates a dimension tensor from a shape";
+  // TODO: Think more about the error situation. Perhaps factor out the
+  // error detection into a separate op so downstream consumers can control
+  // their error behavior. Then this op would assume that the input has
+  // been properly checked to not be an error (and could thus be a
+  // NoSideEffect op).
+  let description = [{
+    Converts a shape to a 1D integral tensor of extents. The number of elements
+    in the tensor equals the rank of the shape, and the elements equal the
+    extents of the shape.
+
+    If the shape represents an error, then this op currently aborts the program.
+  }];
+
+  let arguments = (ins Shape_ShapeType:$input);
+  let results = (outs I32Tensor:$result);
+}
+
 def Shape_JoinOp : Shape_Op<"join", []> {
   let summary = "Returns the least general shape.size of its operands";
   let description = [{
@@ -299,4 +320,50 @@ def Shape_DebugPrintOp : Shape_Op<"debug_print", []> {
   let results =  (outs Shape_ShapeOrSizeType:$output);
 }
 
+def Shape_SplitAtOp : Shape_Op<"split_at",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Splits a shape at a given index.";
+  let description = [{
+    Splits a shape at a given dimension `index`, returning two shapes.
+    If `index` is negative, it is treated as indexing from the back of the
+    shape. This negative-handling behavior is important when handling unranked
+    shapes, where the positive index is not necessarily knowable due to a
+    dynamic number of leading dimensions.
+
+    Examples:
+    - split_at([4,5,6], index=0) -> [], [4,5,6]
+    - split_at([4,5,6], index=1) -> [4], [5,6]
+    - split_at([4,5,6], index=2) -> [4,5], [6]
+    - split_at([4,5,6], index=3) -> [4,5,6], []
+    - split_at([4,5,6], index=4) -> error
+    - split_at([4,5,6], index=-1) -> [4,5], [6]
+    - split_at([4,5,6], index=-2) -> [4], [5,6]
+    - split_at([4,5,6], index=-3) -> [], [4,5,6]
+    - split_at([4,5,6], index=-4) -> error
+
+    Requires:
+    - `index` is in the range [-rank(operand),rank(operand)]
+  }];
+
+  let arguments = (ins Shape_ShapeType:$operand, I32:$index);
+  let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+}
+
+def Shape_ConcatOp : Shape_Op<"concat",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Concatenates two shapes.";
+  let description = [{
+    Creates a shape whose dimensions consist of first the dimensions from `lhs`
+    followed by the dimensions of `rhs`.
+
+    Example:
+    concat([2,3], [4,5]) -> [2,3,4,5]
+    concat([], []) -> []
+    concat([], [4,5,6]) -> [4,5,6]
+  }];
+
+  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+  let results = (outs Shape_ShapeType:$result);
+}
+
 #endif // SHAPE_OPS

diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
index 485a3c710abb..982f73385076 100644
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRShape
   )
 target_link_libraries(MLIRShape
   PUBLIC
+  MLIRInferTypeOpInterface
   MLIRIR
   MLIRSideEffects
   LLVMSupport

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f7f69a64826b..85798ce9eff4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -106,6 +106,33 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
 
 static LogicalResult verify(ConstantOp &op) { return success(); }
 
+//===----------------------------------------------------------------------===//
+// SplitAtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SplitAtOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto shapeType = ShapeType::get(context);
+  inferredReturnTypes.push_back(shapeType);
+  inferredReturnTypes.push_back(shapeType);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConcatOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto shapeType = ShapeType::get(context);
+  inferredReturnTypes.push_back(shapeType);
+  return success();
+}
+
 namespace mlir {
 namespace shape {
 


        


More information about the Mlir-commits mailing list