[PATCH] D80281: [MLIR] Add `num_elements` to the shape dialect
Frederik Gossen via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Tue May 26 02:39:36 PDT 2020
frgossen updated this revision to Diff 266132.
frgossen added a comment.
Remove redundant documentation
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D80281/new/
https://reviews.llvm.org/D80281
Files:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Index: mlir/test/Dialect/Shape/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Shape/canonicalize.mlir
+++ mlir/test/Dialect/Shape/canonicalize.mlir
@@ -160,3 +160,25 @@
%cs = shape.index_to_size %ci
return %cs : !shape.size
}
+
+// -----
+// Fold number of elements computation: the constant shape [4, 5, 6] is
+// expected to fold to the constant size 4 * 5 * 6 = 120, leaving neither the
+// shape constant nor the `num_elements` op in the output.
+// CHECK-LABEL: func @num_elements
+func @num_elements() -> !shape.size {
+ // CHECK-NOT: shape.const_shape
+ %shape = shape.const_shape [4, 5, 6]
+ // CHECK-NOT: shape.num_elements
+ %num_elements = shape.num_elements %shape
+ // CHECK: %[[NUM:.*]] = shape.const_size 120
+ // CHECK-NEXT: return %[[NUM]] : !shape.size
+ return %num_elements : !shape.size
+}
+
+// -----
+// No folding: the shape is a function argument rather than a constant, so
+// `num_elements` must be kept and no constant may be materialized.
+// CHECK-LABEL: func @nonfoldable_num_elements
+func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
+ // CHECK-NOT: shape.const_{{.*}}
+ %num_elements = shape.num_elements %shape
+ return %num_elements : !shape.size
+}
Index: mlir/lib/Dialect/Shape/IR/Shape.cpp
===================================================================
--- mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -318,6 +318,32 @@
return builder.getI64TensorAttr(extents);
}
+//===----------------------------------------------------------------------===//
+// NumElementsOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `shape.num_elements` to a constant index attribute when its shape
+/// operand is a known constant extent tensor. The result is the product of
+/// all extents; a shape with no extents yields 1 because the product starts
+/// at 1. Returns a null fold result (no folding) otherwise.
+OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only when the argument is a constant integer tensor. Use
+  // `dyn_cast_or_null` so an unexpected attribute kind simply prevents
+  // folding instead of asserting inside `cast<>`.
+  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!shape)
+    return {};
+
+  // `APInt::operator*=` requires both operands to have the same bit width, so
+  // normalize every extent to the 64-bit accumulator width first.
+  // NOTE(review): APInt multiplication wraps modulo 2^64 on overflow.
+  APInt product(64, 1);
+  for (APInt extent : shape)
+    product *= extent.sextOrTrunc(64);
+
+  Builder builder(getContext());
+  return builder.getIndexAttr(product.getLimitedValue());
+}
+
+/// Infers the single result type of `shape.num_elements`.
+/// The result is always `!shape.size`; operands, attributes and regions are
+/// not consulted.
+LogicalResult NumElementsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type sizeType = SizeType::get(context);
+  // Append (not assign) to preserve whatever the caller already collected.
+  inferredReturnTypes.emplace_back(sizeType);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
===================================================================
--- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -255,6 +255,24 @@
let results = (outs Shape_SizeType:$result);
}
+// `shape.num_elements`: the product of the extents of a `!shape.shape`,
+// produced as a `!shape.size`. Marked side-effect free (NoSideEffect), its
+// result type comes from InferTypeOpInterface, and it declares a C++ folder
+// (hasFolder).
+def Shape_NumElementsOp : Shape_Op<"num_elements", [
+ NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+ let summary = "Returns the number of elements for a given shape";
+ let description = [{
+ Returns the number of elements for a given shape which is the product of its
+ dimensions.
+ }];
+
+ let arguments = (ins Shape_ShapeType:$shape);
+ let results = (outs Shape_SizeType:$result);
+
+ let assemblyFormat = "attr-dict $shape";
+
+ let hasFolder = 1;
+}
+
def Shape_ReduceOp : Shape_Op<"reduce", []> {
let summary = "Returns an expression reduced over a shape";
let description = [{
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D80281.266132.patch
Type: text/x-patch
Size: 3237 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200526/acfab515/attachment.bin>
More information about the llvm-commits
mailing list