[Mlir-commits] [mlir] e4184c8 - [MLIR][Shape] Make dimension an operand of `get_extent`

Frederik Gossen llvmlistbot at llvm.org
Wed Jun 10 04:47:42 PDT 2020


Author: Frederik Gossen
Date: 2020-06-10T11:47:18Z
New Revision: e4184c84ca0662c73780e94763e83ec245b5a2b0

URL: https://github.com/llvm/llvm-project/commit/e4184c84ca0662c73780e94763e83ec245b5a2b0
DIFF: https://github.com/llvm/llvm-project/commit/e4184c84ca0662c73780e94763e83ec245b5a2b0.diff

LOG: [MLIR][Shape] Make dimension an operand of `get_extent`

The operation `get_extent` now takes the dimension as an SSA operand of type
`!shape.size`, so it is no longer limited to constant dimensions.
For the common constant case, a convenience builder accepting an `int64_t`
and a `getConstantDim()` helper are provided.

Differential Revision: https://reviews.llvm.org/D81248

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2393aa1865a1..334a15ba2459 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -190,24 +190,26 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
   let summary = "Gets the specified extent from a shape";
   let description = [{
     Gets the extent indexed by `dim` from `shape`.
-
     If the shape is an error, it returns an error size.
   }];
   let arguments = (ins
     Shape_ShapeType:$shape,
-    Confined<I64Attr, [IntNonNegative]>:$dim
+    Shape_SizeType:$dim
   );
   let results = (outs Shape_SizeType:$extent);
   let assemblyFormat = "$shape `,` $dim attr-dict";
 
   let builders = [
-    // Builder that allows passing a simple integer instead of an IntegerAttr.
-    OpBuilder<
-      [{OpBuilder &builder, OperationState &result, Value shape, int64_t dim}],
-      [{build(builder, result, shape, builder.getI64IntegerAttr(dim));}]
-    >
+    // Builder that allows passing a constant dimension as a simple integer.
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value shape, "
+              "int64_t dim">
   ];
 
+  let extraClassDeclaration = [{
+    /// Get the `dim` value as integer if it is constant.
+    Optional<int64_t> getConstantDim();
+  }];
+
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 9df20e4871c7..07913273b810 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -405,15 +405,31 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
 // GetExtentOp
 //===----------------------------------------------------------------------===//
 
+Optional<int64_t> GetExtentOp::getConstantDim() {
+  auto constSizeOp = dim().getDefiningOp<ConstSizeOp>();
+  if (!constSizeOp)
+    return llvm::None;
+  return constSizeOp.value().getLimitedValue();
+}
+
 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!elements)
     return nullptr;
-  uint64_t dimToGet = dim().getLimitedValue();
-  // TODO: Constant fold this to some kind of constant error.
-  if (dimToGet >= (uint64_t)elements.getNumElements())
+  // Fold only constant dims that index into the known extents. `dim` is now
+  // an operand without a non-negativity constraint, so reject negatives too.
+  Optional<int64_t> dim = getConstantDim();
+  if (!dim || *dim < 0 || *dim >= elements.getNumElements())
     return nullptr;
-  return elements.getValue({dimToGet});
+  return elements.getValue({static_cast<uint64_t>(*dim)});
+}
+
+void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
+                        int64_t dim) {
+  // Materialize the constant dimension as a `shape.const_size` operand.
+  Value dimValue =
+      builder.create<ConstSizeOp>(result.location, builder.getIndexAttr(dim));
+  build(builder, result, shape, dimValue);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 7c90753e255e..6693e44d9b30 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,6 +1,5 @@
 // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s --dump-input=fail
 
-// -----
 // CHECK-LABEL: func @f
 func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
   // CHECK: shape.const_shape [2, 3, 4]
@@ -9,6 +8,7 @@ func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
 }
 
 // -----
+
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> (!shape.shape, !shape.shape) {
@@ -22,6 +22,7 @@ func @f() -> (!shape.shape, !shape.shape) {
 }
 
 // -----
+
 // Negative split point.
 // CHECK-LABEL: func @f
 func @f() -> (!shape.shape, !shape.shape) {
@@ -34,6 +35,7 @@ func @f() -> (!shape.shape, !shape.shape) {
 }
 
 // -----
+
 // Out of range split point. No folding.
 // CHECK-LABEL: func @f
 func @f() -> (!shape.shape, !shape.shape) {
@@ -45,6 +47,7 @@ func @f() -> (!shape.shape, !shape.shape) {
 }
 
 // -----
+
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
@@ -56,6 +59,7 @@ func @f() -> !shape.shape {
 }
 
 // -----
+
 // Incompatible shapes. No folding.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
@@ -67,6 +71,7 @@ func @f() -> !shape.shape {
 }
 
 // -----
+
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
@@ -78,6 +83,7 @@ func @f() -> !shape.shape {
 }
 
 // -----
+
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> tensor<2xindex> {
@@ -88,6 +94,7 @@ func @f() -> tensor<2xindex> {
 }
 
 // -----
+
 // Basic case.
 // CHECK-LABEL: func @f()
 func @f() -> !shape.shape {
@@ -99,6 +106,8 @@ func @f() -> !shape.shape {
   return %ret : !shape.shape
 }
 
+// -----
+
 // CHECK-LABEL: func @no_fold
 func @no_fold(%arg0: index) -> !shape.shape {
   // CHECK-NOT: shape.const_shape
@@ -108,6 +117,7 @@ func @no_fold(%arg0: index) -> !shape.shape {
 }
 
 // -----
+
 // Cast constant size to index and fold it away.
 // CHECK-LABEL: func @const_size_to_index
 func @const_size_to_index() -> index {
@@ -119,6 +129,7 @@ func @const_size_to_index() -> index {
 }
 
 // -----
+
 // Cast constant index to size and fold it away.
 // CHECK-LABEL: func @const_index_to_size
 func @const_index_to_size() -> !shape.size {
@@ -130,6 +141,7 @@ func @const_index_to_size() -> !shape.size {
 }
 
 // -----
+
 // Cast constant index to size, then back, and fold it away.
 // CHECK-LABEL: func @const_index_to_size_to_index
 func @const_index_to_size_to_index() -> index {
@@ -143,6 +155,7 @@ func @const_index_to_size_to_index() -> index {
 }
 
 // -----
+
 // No folding.
 // CHECK-LABEL: func @nonfoldable_size_to_index
 func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
@@ -152,6 +165,7 @@ func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
 }
 
 // -----
+
 // No folding.
 // CHECK-LABEL: func @nonfoldable_index_to_size
 func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
@@ -161,6 +175,7 @@ func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
 }
 
 // -----
+
 // Fold number of elements computation.
 // CHECK-LABEL: func @num_elements
 func @num_elements() -> !shape.size {
@@ -174,6 +189,7 @@ func @num_elements() -> !shape.size {
 }
 
 // -----
+
 // No folding.
 // CHECK-LABEL: func @nonfoldable_num_elements
 func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
@@ -184,32 +200,37 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
 
 // -----
 
-// Canonicalization of shape.get_extent
-
 // Basic folding.
 // CHECK-LABEL: func @basic
 func @basic() -> !shape.size {
   // CHECK: shape.const_size 2
   %0 = shape.const_shape [0, 1, 2]
-  %1 = shape.get_extent %0, 2
+  %c2 = shape.const_size 2
+  %1 = shape.get_extent %0, %c2
   return %1 : !shape.size
 }
 
+// -----
+
 // Should not fold.
 // CHECK-LABEL: func @out_of_bounds
 func @out_of_bounds() -> !shape.size {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
   %0 = shape.const_shape [0, 1, 2]
-  %1 = shape.get_extent %0, 3
+  %c3 = shape.const_size 3
+  %1 = shape.get_extent %0, %c3
   return %1 : !shape.size
 }
 
+// -----
+
 // Should not fold.
 // CHECK-LABEL: func @not_const
 func @not_const(%arg0: !shape.shape) -> !shape.size {
   // CHECK: shape.get_extent
-  %0 = shape.get_extent %arg0, 3
+  %c3 = shape.const_size 3
+  %0 = shape.get_extent %arg0, %c3
   return %0 : !shape.size
 }
 
@@ -270,6 +291,7 @@ func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
 }
 
 // -----
+
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f
 func @f() {
@@ -285,6 +307,7 @@ func @f() {
 }
 
 // -----
+
 // assuming_all should not be removed if not all witnesses are statically passing.
 //
 // Additionally check that the attribute is moved to the end as this op is
@@ -303,6 +326,7 @@ func @f() {
 }
 
 // -----
+
 // any can be replaced with a constant input if it has one.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape) -> !shape.shape {
@@ -315,6 +339,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
 
 
 // -----
+
 // Folding of any with partially constant operands is not yet implemented.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
@@ -325,6 +350,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
 }
 
 // -----
+
 // assuming with a known passing witness can be removed
 // CHECK-LABEL: func @f
 func @f() {
@@ -341,6 +367,7 @@ func @f() {
 }
 
 // -----
+
 // assuming without a known passing passing witness cannot be removed
 // CHECK-LABEL: func @f
 func @f() {


        


More information about the Mlir-commits mailing list