[Mlir-commits] [mlir] 25132b3 - [mlir][shape] Use IndexElementsAttr in Shape dialect.

Sean Silva llvmlistbot at llvm.org
Wed May 27 13:40:09 PDT 2020


Author: Sean Silva
Date: 2020-05-27T13:39:49-07:00
New Revision: 25132b36a8b39e7c2b0b28aa73772e57191b6df4

URL: https://github.com/llvm/llvm-project/commit/25132b36a8b39e7c2b0b28aa73772e57191b6df4
DIFF: https://github.com/llvm/llvm-project/commit/25132b36a8b39e7c2b0b28aa73772e57191b6df4.diff

LOG: [mlir][shape] Use IndexElementsAttr in Shape dialect.

Summary:
Index is the proper element type for storing shapes when constant folding,
so this switches the Shape dialect's constant shape attributes (and the folders
that produce them) from i64 to index element types.

Differential Revision: https://reviews.llvm.org/D80600

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index a9759fc6a734..406aac2db99a 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -102,7 +102,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
     %1 = shape.const_shape [1, 2, 3]
     ```
   }];
-  let arguments = (ins I64ElementsAttr:$shape);
+  let arguments = (ins IndexElementsAttr:$shape);
   let results = (outs Shape_ShapeType:$result);
 
   // TODO: Move this to main so that all shape ops implement these.
@@ -206,13 +206,8 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
   let builders = [
     // Builder that allows passing a simple integer instead of an IntegerAttr.
     OpBuilder<
-      [{
-        OpBuilder &builder, OperationState &result,
-        Value shape, int64_t dim
-      }],
-      [{
-        build(builder, result, shape, builder.getI64IntegerAttr(dim));
-      }]
+      [{OpBuilder &builder, OperationState &result, Value shape, int64_t dim}],
+      [{build(builder, result, shape, builder.getI64IntegerAttr(dim));}]
     >
   ];
 

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index fa9552fc8694..c4a8b1529817 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -177,7 +177,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
     return nullptr;
   Builder builder(getContext());
-  return builder.getI64TensorAttr(resultShape);
+  return builder.getIndexTensorAttr(resultShape);
 }
 
 //===----------------------------------------------------------------------===//
@@ -215,7 +215,7 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
     ints.push_back(attr.getInt());
   }
   Builder &builder = parser.getBuilder();
-  result.addAttribute("shape", builder.getI64TensorAttr(ints));
+  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
 
   result.types.push_back(ShapeType::get(builder.getContext()));
   return success();
@@ -257,7 +257,7 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
   for (auto attr : operands)
     extents.push_back(attr.cast<IntegerAttr>().getInt());
   Builder builder(getContext());
-  return builder.getI64TensorAttr(extents);
+  return builder.getIndexTensorAttr(extents);
 }
 
 //===----------------------------------------------------------------------===//
@@ -281,14 +281,7 @@ OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
   // TODO: Constant fold this to some kind of constant error.
   if (dimToGet >= (uint64_t)elements.getNumElements())
     return nullptr;
-  // This is a little inconvenient because getValue returns an IntegerAttr
-  // that is not of IndexType, but the result here needs to be of
-  // IndexType.
-  // TODO: Make ConstShapeOp hold an tensor of index instead of i64.
-  Builder builder(getContext());
-  return builder.getIntegerAttr(
-      builder.getIndexType(),
-      elements.getValue<IntegerAttr>({dimToGet}).getInt());
+  return elements.getValue({dimToGet});
 }
 
 //===----------------------------------------------------------------------===//
@@ -309,7 +302,7 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   if (!type || !type.hasStaticShape())
     return nullptr;
   Builder builder(getContext());
-  return builder.getI64TensorAttr(type.getShape());
+  return builder.getIndexTensorAttr(type.getShape());
 }
 
 //===----------------------------------------------------------------------===//
@@ -343,8 +336,8 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
   if (splitPoint < 0)
     splitPoint += shape.size();
   Builder builder(operands[0].getContext());
-  results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint)));
-  results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint)));
+  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
+  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
   return success();
 }
 
@@ -373,7 +366,7 @@ OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
   resultShape.append(lhsShape.begin(), lhsShape.end());
   resultShape.append(rhsShape.begin(), rhsShape.end());
   Builder builder(getContext());
-  return builder.getI64TensorAttr(resultShape);
+  return builder.getIndexTensorAttr(resultShape);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 018f5b212b4e..23147e557a15 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -15,7 +15,7 @@ func @f() -> (!shape.shape, !shape.shape) {
   // CHECK: shape.const_shape [2, 3]
   // CHECK: shape.const_shape [4, 5]
   %c2 = constant 2 : i32
-  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %0 = shape.const_shape [2, 3, 4, 5]
   %head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
   return %head, %tail : !shape.shape, !shape.shape
 
@@ -28,7 +28,7 @@ func @f() -> (!shape.shape, !shape.shape) {
   // CHECK: shape.const_shape [2, 3, 4]
   // CHECK: shape.const_shape [5]
   %c-1 = constant -1 : i32
-  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %0 = shape.const_shape [2, 3, 4, 5]
   %head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
   return %head, %tail : !shape.shape, !shape.shape
 }
@@ -39,7 +39,7 @@ func @f() -> (!shape.shape, !shape.shape) {
 func @f() -> (!shape.shape, !shape.shape) {
   // CHECK: shape.split_at
   %c5 = constant 5 : i32
-  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %0 = shape.const_shape [2, 3, 4, 5]
   %head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
   return %head, %tail : !shape.shape, !shape.shape
 }


        


More information about the Mlir-commits mailing list