[Mlir-commits] [mlir] 21b0eff - [mlir][shape] Add `shape.from_extents`.

Sean Silva llvmlistbot at llvm.org
Tue May 19 14:26:23 PDT 2020


Author: Sean Silva
Date: 2020-05-19T14:26:08-07:00
New Revision: 21b0eff7738a0ca0b23c5481e67e33e583b1a378

URL: https://github.com/llvm/llvm-project/commit/21b0eff7738a0ca0b23c5481e67e33e583b1a378
DIFF: https://github.com/llvm/llvm-project/commit/21b0eff7738a0ca0b23c5481e67e33e583b1a378.diff

LOG: [mlir][shape] Add `shape.from_extents`.

Summary:
This is a basic op needed for creating shapes from SSA values
representing the extents.

Differential Revision: https://reviews.llvm.org/D79833

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 47825577921e..074a54f9e5ae 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -132,6 +132,30 @@ def Shape_ConstSizeOp : Shape_Op<"const_size",
   let assemblyFormat = "attr-dict $value";
 }
 
+def Shape_FromExtentsOp : Shape_Op<"from_extents", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>  // Result type is inferred.
+    ]> {
+  let summary = "Creates a shape from extents";
+  let description = [{
+    Creates a shape from multiple SSA values representing the extents of
+    the shape.
+
+    ```mlir
+    // Rank 2 shape.
+    %s0 = shape.from_extents %a, %b
+    // Rank 0 shape.
+    %s1 = shape.from_extents
+    ```
+  }];
+  let arguments = (ins Variadic<Index>:$extents);  // Zero or more extents.
+  let results = (outs Shape_ShapeType:$shape);
+
+  let assemblyFormat = "attr-dict $extents";  // No types appear in the syntax.
+
+  let hasFolder = 1;  // Implemented by FromExtentsOp::fold in Shape.cpp.
+}
+
 def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
   let summary = "Creates a shape from a tensor of extents";
   let description = [{

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a66fa8a8128a..e1d1b3365699 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -201,6 +201,28 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FromExtentsOp
+//===----------------------------------------------------------------------===//
+/// Infers the result type: appends one ShapeType and always succeeds.
+LogicalResult FromExtentsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
+OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only if no operand attribute is null; result is an i64 tensor attr.
+  if (!llvm::all_of(operands, [](Attribute a) { return static_cast<bool>(a); }))
+    return nullptr;
+  SmallVector<int64_t, 6> extents;
+  for (Attribute operand : operands)
+    extents.push_back(operand.cast<IntegerAttr>().getInt());
+  return Builder(getContext()).getI64TensorAttr(extents);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index ee69f90553d9..2e35fc748d86 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,3 +86,23 @@ func @f() -> tensor<2xindex> {
   %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex>
   return %0 : tensor<2xindex>
 }
+
+// -----
+// Basic case: every extent is a constant, so the op folds to a const_shape.
+// CHECK-LABEL: func @f()
+func @f() -> !shape.shape {
+  // CHECK: shape.const_shape [3, 5, 11]
+  %e0 = constant 3 : index
+  %e1 = constant 5 : index
+  %e2 = constant 11 : index
+  %ret = shape.from_extents %e0, %e1, %e2
+  return %ret : !shape.shape
+}
+
+// CHECK-LABEL: func @no_fold
+func @no_fold(%arg0: index) -> !shape.shape {
+  // CHECK-NOT: shape.const_shape
+  %e0 = constant 3 : index
+  %ret = shape.from_extents %e0, %arg0  // %arg0 is not a constant: no fold.
+  return %ret : !shape.shape
+}


        


More information about the Mlir-commits mailing list