[Mlir-commits] [mlir] 6594d54 - [MLIR] Add `index_to_size` and `size_to_index` to the shape dialect
Frederik Gossen
llvmlistbot at llvm.org
Thu May 28 06:58:21 PDT 2020
Author: Frederik Gossen
Date: 2020-05-28T13:57:20Z
New Revision: 6594d54571ee5887f031555a7660b8d8e74194d3
URL: https://github.com/llvm/llvm-project/commit/6594d54571ee5887f031555a7660b8d8e74194d3
DIFF: https://github.com/llvm/llvm-project/commit/6594d54571ee5887f031555a7660b8d8e74194d3.diff
LOG: [MLIR] Add `index_to_size` and `size_to_index` to the shape dialect
Add the two conversion operations `index_to_size` and `size_to_index` to the
shape dialect.
This facilitates the conversion of index types between the shape and the
standard dialect.
Differential Revision: https://reviews.llvm.org/D80280
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index dddc4c3ea08c..57d1954a3199 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -214,6 +214,25 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
let hasFolder = 1;
}
+// Conversion op from the standard dialect's `index` type to `shape.size`.
+// Declared side-effect free; its result type inference and folder are
+// implemented in Shape.cpp.
+def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [
+ NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Converts a standard index to a shape size";
+ let description = [{
+ Converts a standard index to a `shape.size`.
+ This operation and its inverse, `size_to_index`, facilitate index conversion
+ between the standard and the shape dialect.
+ The behavior is undefined for negative indices.
+ }];
+
+ // One `index` operand, one `shape.size` result.
+ let arguments = (ins Index:$arg);
+ let results = (outs Shape_SizeType:$result);
+
+ // Textual form: the attribute dictionary followed by the operand.
+ let assemblyFormat = "attr-dict $arg";
+
+ // Requests a `fold` hook, which forwards a constant operand unchanged.
+ let hasFolder = 1;
+}
+
def Shape_JoinOp : Shape_Op<"join", []> {
let summary = "Returns the least general shape.size of its operands";
let description = [{
@@ -312,6 +331,25 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
let hasFolder = 1;
}
+// Conversion op from `shape.size` to the standard dialect's `index` type.
+// Declared side-effect free; its result type inference and folder are
+// implemented in Shape.cpp.
+def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
+ NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ // The op is one-directional (size -> index); the summary mirrors the one on
+ // `index_to_size` instead of describing a generic bidirectional cast.
+ let summary = "Converts a shape size to a standard index";
+ let description = [{
+ Converts a `shape.size` to a standard index.
+ This operation and its inverse, `index_to_size`, facilitate index conversion
+ between the standard and the shape dialect.
+ The behavior is undefined for unknown and invalid arguments.
+ }];
+
+ // One `shape.size` operand, one `index` result.
+ let arguments = (ins Shape_SizeType:$arg);
+ let results = (outs Index:$result);
+
+ // Textual form: the attribute dictionary followed by the operand.
+ let assemblyFormat = "attr-dict $arg";
+
+ // Requests a `fold` hook, which forwards a constant operand unchanged.
+ let hasFolder = 1;
+}
+
def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
let summary = "Returns the value to parent op";
@@ -523,7 +561,6 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
let assemblyFormat = "$inputs attr-dict";
}
-
// Canonicalization patterns.
#endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index fc8f9b23e1e4..a077948fdd31 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -249,7 +249,7 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
return success();
}
-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
//===----------------------------------------------------------------------===//
// ConstSizeOp
@@ -266,6 +266,26 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
+//===----------------------------------------------------------------------===//
+// IndexToSizeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+ // Constant `index` and constant `shape.size` values are both represented as
+ // `IntegerAttr`s, so a constant operand is forwarded unchanged. When the
+ // operand attribute is null an empty result is returned and the op stays.
+ Attribute constantIndex = operands.front();
+ return constantIndex ? OpFoldResult(constantIndex) : OpFoldResult();
+}
+
+LogicalResult IndexToSizeOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ // The single result is always of type `shape.size`; only the context is
+ // consulted, none of the other arguments are inspected.
+ Type sizeType = SizeType::get(context);
+ inferredReturnTypes.push_back(sizeType);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//
@@ -333,6 +353,26 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
return builder.getIndexTensorAttr(type.getShape());
}
+//===----------------------------------------------------------------------===//
+// SizeToIndexOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+ // Constant `shape.size` and constant `index` values are both represented as
+ // `IntegerAttr`s, so a constant operand is forwarded unchanged. When the
+ // operand attribute is null an empty result is returned and the op stays.
+ Attribute constantSize = operands.front();
+ return constantSize ? OpFoldResult(constantSize) : OpFoldResult();
+}
+
+LogicalResult SizeToIndexOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ // The single result is always of type `index`; only the context is
+ // consulted, none of the other arguments are inspected.
+ Type indexType = IndexType::get(context);
+ inferredReturnTypes.push_back(indexType);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 23147e557a15..106171de6087 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -108,6 +108,60 @@ func @no_fold(%arg0: index) -> !shape.shape {
}
// -----
+// Cast constant size to index and fold it away.
+// The negative check names the conversion op that folding must remove.
+// CHECK-LABEL: func @const_size_to_index
+func @const_size_to_index() -> index {
+ // CHECK-NOT: shape.size_to_index
+ %cs = shape.const_size 123
+ // CHECK: constant 123 : index
+ %ci = shape.size_to_index %cs
+ return %ci : index
+}
+
+// -----
+// Cast constant index to size and fold it away.
+// The negative check names the conversion op that folding must remove.
+// CHECK-LABEL: func @const_index_to_size
+func @const_index_to_size() -> !shape.size {
+ // CHECK-NOT: shape.index_to_size
+ %ci = constant 123 : index
+ // CHECK: shape.const_size 123
+ %cs = shape.index_to_size %ci
+ return %cs : !shape.size
+}
+
+// -----
+// Cast constant index to size, then back, and fold it away.
+// Both conversion ops must be gone, so each one gets its own negative check.
+// CHECK-LABEL: func @const_index_to_size_to_index
+func @const_index_to_size_to_index() -> index {
+ // CHECK-NOT: shape.index_to_size
+ // CHECK-NOT: shape.size_to_index
+ %ci0 = constant 123 : index
+ %cs0 = shape.index_to_size %ci0
+ // CHECK: %[[CI:.*]] = constant 123 : index
+ // CHECK-NEXT: return %[[CI]] : index
+ %ci1 = shape.size_to_index %cs0
+ return %ci1 : index
+}
+
+// -----
+// The size is a function argument rather than a constant, so the conversion
+// op has to survive canonicalization.
+// CHECK-LABEL: func @nonfoldable_size_to_index
+func @nonfoldable_size_to_index(%sz : !shape.size) -> index {
+ // CHECK: shape.size_to_index
+ %idx = shape.size_to_index %sz
+ return %idx : index
+}
+
+// -----
+// The index is a function argument rather than a constant, so the conversion
+// op has to survive canonicalization.
+// CHECK-LABEL: func @nonfoldable_index_to_size
+func @nonfoldable_index_to_size(%idx : index) -> !shape.size {
+ // CHECK: shape.index_to_size
+ %sz = shape.index_to_size %idx
+ return %sz : !shape.size
+}
+
+// -----
+
// Canonicalization of shape.get_extent
// Basic folding.
More information about the Mlir-commits
mailing list