[Mlir-commits] [mlir] 6594d54 - [MLIR] Add `index_to_size` and `size_to_index` to the shape dialect

Frederik Gossen llvmlistbot at llvm.org
Thu May 28 06:58:21 PDT 2020


Author: Frederik Gossen
Date: 2020-05-28T13:57:20Z
New Revision: 6594d54571ee5887f031555a7660b8d8e74194d3

URL: https://github.com/llvm/llvm-project/commit/6594d54571ee5887f031555a7660b8d8e74194d3
DIFF: https://github.com/llvm/llvm-project/commit/6594d54571ee5887f031555a7660b8d8e74194d3.diff

LOG: [MLIR] Add `index_to_size` and `size_to_index` to the shape dialect

Add the two conversion operations `index_to_size` and `size_to_index` to the
shape dialect.
This facilitates the conversion of index values between the shape and the
standard dialects.

Differential Revision: https://reviews.llvm.org/D80280

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index dddc4c3ea08c..57d1954a3199 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -214,6 +214,25 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
   let hasFolder = 1;
 }
 
+def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a standard index to a shape size";
+  let description = [{
+    Turns a value of type `index` into a `shape.size`. The inverse conversion
+    is `size_to_index`; together, the two ops bridge index values between the
+    standard and shape dialects.
+    Negative indices lead to undefined behavior.
+  }];
+
+  let results = (outs Shape_SizeType:$result);
+  let arguments = (ins Index:$arg);
+
+  let hasFolder = 1;
+
+  let assemblyFormat = "attr-dict $arg";
+}
+
 def Shape_JoinOp : Shape_Op<"join", []> {
   let summary = "Returns the least general shape.size of its operands";
   let description = [{
@@ -312,6 +331,25 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   let hasFolder = 1;
 }
 
+def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a shape size to a standard index";
+  let description = [{
+    Converts a `shape.size` to a standard index.
+    This operation and its inverse, `index_to_size`, facilitate index conversion
+    between the standard and shape dialects.
+    The behavior is undefined for unknown and invalid arguments.
+  }];
+
+  let arguments = (ins Shape_SizeType:$arg);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
+}
+
 def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
   let summary = "Returns the value to parent op";
 
@@ -523,7 +561,6 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
   let assemblyFormat = "$inputs attr-dict";
 }
 
-
 // Canonicalization patterns.
 
 #endif // SHAPE_OPS

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index fc8f9b23e1e4..a077948fdd31 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -249,7 +249,8 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
   return success();
 }
 
-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+// The op is a constant, so it always folds to its own `shape` attribute.
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
 //===----------------------------------------------------------------------===//
 // ConstSizeOp
@@ -266,6 +266,26 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
 
 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
 
+//===----------------------------------------------------------------------===//
+// IndexToSizeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+  // Constants of `index` and of `shape.size` are both `IntegerAttr`s, so a
+  // constant operand is forwarded unchanged; a null entry (non-constant
+  // operand) yields an empty result, i.e. nothing is folded.
+  Attribute constantArg = operands.front();
+  return constantArg ? OpFoldResult(constantArg) : OpFoldResult();
+}
+
+// The result type is fixed to `!shape.size`; only the context is consulted.
+LogicalResult IndexToSizeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location>, ValueRange, DictionaryAttr,
+    RegionRange, SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.emplace_back(SizeType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // FromExtentsOp
 //===----------------------------------------------------------------------===//
@@ -333,6 +353,26 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   return builder.getIndexTensorAttr(type.getShape());
 }
 
+//===----------------------------------------------------------------------===//
+// SizeToIndexOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+  // A constant operand is already the folded value: `shape.size` and `index`
+  // constants share the `IntegerAttr` representation. A null entry means the
+  // operand is not constant, in which case no fold result is produced.
+  Attribute constSize = operands.front();
+  return !constSize ? OpFoldResult() : OpFoldResult(constSize);
+}
+
+// The inferred result type is always `index`; only the context is consulted.
+LogicalResult SizeToIndexOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location>, ValueRange, DictionaryAttr,
+    RegionRange, SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.emplace_back(IndexType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitAtOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 23147e557a15..106171de6087 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -108,6 +108,60 @@ func @no_fold(%arg0: index) -> !shape.shape {
 }
 
 // -----
+// Cast constant size to index and fold it away.
+// CHECK-LABEL: func @const_size_to_index
+func @const_size_to_index() -> index {
+  // CHECK-NOT: shape.size_to_index
+  %cs = shape.const_size 123
+  // CHECK: constant 123 : index
+  %ci = shape.size_to_index %cs
+  return %ci : index
+}
+
+// -----
+// Cast constant index to size and fold it away.
+// CHECK-LABEL: func @const_index_to_size
+func @const_index_to_size() -> !shape.size {
+  // CHECK-NOT: shape.index_to_size
+  %ci = constant 123 : index
+  // CHECK: shape.const_size 123
+  %cs = shape.index_to_size %ci
+  return %cs : !shape.size
+}
+
+// -----
+// Cast constant index to size, then back, and fold it away.
+// CHECK-LABEL: func @const_index_to_size_to_index
+func @const_index_to_size_to_index() -> index {
+  // CHECK-NOT: shape.{{index_to_size|size_to_index}}
+  %ci0 = constant 123 : index
+  %cs0 = shape.index_to_size %ci0
+  // CHECK: %[[CI:.*]] = constant 123 : index
+  // CHECK-NEXT: return %[[CI]] : index
+  %ci1 = shape.size_to_index %cs0
+  return %ci1 : index
+}
+
+// -----
+// A non-constant size operand must be left untouched.
+// CHECK-LABEL: func @nonfoldable_size_to_index
+func @nonfoldable_size_to_index(%size : !shape.size) -> index {
+  // CHECK: shape.size_to_index
+  %index = shape.size_to_index %size
+  return %index : index
+}
+
+// -----
+// A non-constant index operand must be left untouched.
+// CHECK-LABEL: func @nonfoldable_index_to_size
+func @nonfoldable_index_to_size(%index : index) -> !shape.size {
+  // CHECK: shape.index_to_size
+  %size = shape.index_to_size %index
+  return %size : !shape.size
+}
+
+// -----
+
 // Canonicalization of shape.get_extent
 
 // Basic folding.


        


More information about the Mlir-commits mailing list