[Mlir-commits] [mlir] cf42b70 - [mlir][shape] Add `shape.get_extent`.

Sean Silva llvmlistbot at llvm.org
Tue May 26 17:06:42 PDT 2020


Author: Sean Silva
Date: 2020-05-26T17:03:40-07:00
New Revision: cf42b704391c44e84485dd2547ae006196998266

URL: https://github.com/llvm/llvm-project/commit/cf42b704391c44e84485dd2547ae006196998266
DIFF: https://github.com/llvm/llvm-project/commit/cf42b704391c44e84485dd2547ae006196998266.diff

LOG: [mlir][shape] Add `shape.get_extent`.

Summary:
This op extracts an extent from a shape.

This is also the first op that constant-folds to shape.const_size,
which revealed that shape.const_size needs a folder (ConstantLike ops
seem to always need folders for the constant-folding infra to work).

Differential Revision: https://reviews.llvm.org/D80394

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 7d62cebff8e6..0278d7bbeb06 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -131,6 +131,7 @@ def Shape_ConstSizeOp : Shape_Op<"const_size",
   let results = (outs Shape_SizeType:$result);
 
   let assemblyFormat = "attr-dict $value";
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [
@@ -190,6 +191,37 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", []> {
   let hasFolder = 1;
 }
 
+def Shape_GetExtentOp : Shape_Op<"get_extent",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Gets the specified extent from a shape";
+  let description = [{
+    Gets the extent indexed by `dim` from `shape`.
+
+    If the shape is an error, it returns an error size.
+  }];
+  let arguments = (ins
+    Shape_ShapeType:$shape,
+    Confined<I64Attr, [IntNonNegative]>:$dim
+  );
+  let results = (outs Shape_SizeType:$extent);
+  let assemblyFormat = "$shape `,` $dim attr-dict";
+
+  let builders = [
+    // Builder that allows passing a simple integer instead of an IntegerAttr.
+    OpBuilder<
+      [{
+        OpBuilder &builder, OperationState &result,
+        Value shape, int64_t dim
+      }],
+      [{
+        build(builder, result, shape, builder.getI64IntegerAttr(dim));
+      }]
+    >
+  ];
+
+  let hasFolder = 1;
+}
+
 def Shape_JoinOp : Shape_Op<"join", []> {
   let summary = "Returns the least general shape.size of its operands";
   let description = [{

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 5c6a0c2204c3..095c41720fba 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -245,6 +245,8 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
   return success();
 }
 
+OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
+
 //===----------------------------------------------------------------------===//
 // FromExtentsOp
 //===----------------------------------------------------------------------===//
@@ -267,6 +269,37 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
   return builder.getI64TensorAttr(extents);
 }
 
+//===----------------------------------------------------------------------===//
+// GetExtentOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GetExtentOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+                              ValueRange operands, DictionaryAttr attributes,
+                              RegionRange regions,
+                              SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
+}
+
+OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
+  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!elements)
+    return nullptr;
+  uint64_t dimToGet = dim().getLimitedValue();
+  // TODO: Constant fold this to some kind of constant error.
+  if (dimToGet >= (uint64_t)elements.getNumElements())
+    return nullptr;
+  // This is a little inconvenient because getValue returns an IntegerAttr
+  // that is not of IndexType, but the result here needs to be of
+  // IndexType.
+  // TODO: Make ConstShapeOp hold a tensor of index instead of i64.
+  Builder builder(getContext());
+  return builder.getIntegerAttr(
+      builder.getIndexType(),
+      elements.getValue<IntegerAttr>({dimToGet}).getInt());
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 2e35fc748d86..018f5b212b4e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -106,3 +106,33 @@ func @no_fold(%arg0: index) -> !shape.shape {
   %ret = shape.from_extents %e0, %arg0
   return %ret : !shape.shape
 }
+
+// -----
+// Canonicalization of shape.get_extent
+
+// Basic folding.
+// CHECK-LABEL: func @basic
+func @basic() -> !shape.size {
+  // CHECK: shape.const_size 2
+  %0 = shape.const_shape [0, 1, 2]
+  %1 = shape.get_extent %0, 2
+  return %1 : !shape.size
+}
+
+// Should not fold.
+// CHECK-LABEL: func @out_of_bounds
+func @out_of_bounds() -> !shape.size {
+  // CHECK: shape.const_shape
+  // CHECK: shape.get_extent
+  %0 = shape.const_shape [0, 1, 2]
+  %1 = shape.get_extent %0, 3
+  return %1 : !shape.size
+}
+
+// Should not fold.
+// CHECK-LABEL: func @not_const
+func @not_const(%arg0: !shape.shape) -> !shape.size {
+  // CHECK: shape.get_extent
+  %0 = shape.get_extent %arg0, 3
+  return %0 : !shape.size
+}


        


More information about the Mlir-commits mailing list