[Mlir-commits] [mlir] 5fff169 - [shape] More constant folding

Sean Silva llvmlistbot at llvm.org
Fri Apr 24 16:10:39 PDT 2020


Author: Sean Silva
Date: 2020-04-24T16:10:19-07:00
New Revision: 5fff169daa11f3e574fe43f93431d36593905cb7

URL: https://github.com/llvm/llvm-project/commit/5fff169daa11f3e574fe43f93431d36593905cb7
DIFF: https://github.com/llvm/llvm-project/commit/5fff169daa11f3e574fe43f93431d36593905cb7.diff

LOG: [shape] More constant folding

- shape.split_at
- shape.broadcast
- shape.concat
- shape.to_extent_tensor

Differential Revision: https://reviews.llvm.org/D78821
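
As a quick illustration (mirroring the broadcast test added below), canonicalization now rewrites

  %0 = shape.const_shape [1, 2]
  %1 = shape.const_shape [7, 1]
  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape

into

  %2 = shape.const_shape [7, 2]

Folds that cannot be performed safely (incompatible broadcast shapes, out-of-range split points) produce no result and leave the op untouched; folding to an "error" value is left as a TODO.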

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/CMakeLists.txt
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 57f60068d7cd..f54456b862fa 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -152,6 +152,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
   let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs,
                    OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeType:$result);
+
+  let hasFolder = 1;
 }
 
 def Shape_ConstShapeOp : Shape_Op<"const_shape",
@@ -225,6 +227,8 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", []> {
 
   let arguments = (ins Shape_ShapeType:$input);
   let results = (outs IndexTensor:$result);
+
+  let hasFolder = 1;
 }
 
 def Shape_JoinOp : Shape_Op<"join", []> {
@@ -376,6 +380,7 @@ def Shape_SplitAtOp : Shape_Op<"split_at",
 
   let arguments = (ins Shape_ShapeType:$operand, I32:$index);
   let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+  let hasFolder = 1;
 }
 
 def Shape_ConcatOp : Shape_Op<"concat",
@@ -393,6 +398,8 @@ def Shape_ConcatOp : Shape_Op<"concat",
 
   let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
   let results = (outs Shape_ShapeType:$result);
+
+  let hasFolder = 1;
 }
 
 #endif // SHAPE_OPS

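Note on the ODS change above: setting hasFolder = 1 only declares the folding hook on the generated op class. Single-result ops (broadcast, concat, to_extent_tensor) get the OpFoldResult fold(ArrayRef<Attribute>) form, while the multi-result split_at gets the LogicalResult fold(ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) variant. The definitions follow in Shape.cpp below.
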
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
index 982f73385076..4ed02acc3d46 100644
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRShape
   )
 target_link_libraries(MLIRShape
   PUBLIC
+  MLIRDialect
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRSideEffects

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3d3c1a9b6454..4a1c0f1d5128 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Shape/IR/Shape.h"
 
+#include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -87,6 +88,26 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
   }
 }
 
+//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1])
+    return nullptr;
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  SmallVector<int64_t, 6> resultShape;
+  // If the shapes are not compatible, we can't fold it.
+  // TODO: Fold to an "error".
+  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
+    return nullptr;
+  Builder builder(getContext());
+  return builder.getI64TensorAttr(resultShape);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstShapeOp
 //===----------------------------------------------------------------------===//
@@ -176,6 +197,27 @@ LogicalResult SplitAtOp::inferReturnTypes(
   return success();
 }
 
+LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
+                              SmallVectorImpl<OpFoldResult> &results) {
+  if (!operands[0] || !operands[1])
+    return failure();
+  auto shapeVec = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto shape = llvm::makeArrayRef(shapeVec);
+  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
+  // Verify that the split point is in the correct range.
+  // TODO: Constant fold to an "error".
+  int64_t rank = shape.size();
+  if (!(-rank <= splitPoint && splitPoint <= rank))
+    return failure();
+  if (splitPoint < 0)
+    splitPoint += shape.size();
+  Builder builder(operands[0].getContext());
+  results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint)));
+  results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint)));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOp
 //===----------------------------------------------------------------------===//
@@ -189,6 +231,35 @@ LogicalResult ConcatOp::inferReturnTypes(
   return success();
 }
 
+OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1])
+    return nullptr;
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  SmallVector<int64_t, 6> resultShape;
+  resultShape.append(lhsShape.begin(), lhsShape.end());
+  resultShape.append(rhsShape.begin(), rhsShape.end());
+  Builder builder(getContext());
+  return builder.getI64TensorAttr(resultShape);
+}
+
+//===----------------------------------------------------------------------===//
+// ToExtentTensorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0])
+    return nullptr;
+  Builder builder(getContext());
+  auto shape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
+                                    builder.getIndexType());
+  return DenseIntElementsAttr::get(type, shape);
+}
+
 namespace mlir {
 namespace shape {
 

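Worked example for the split_at fold above: with the constant shape [2, 3, 4, 5] (rank 4) and split point -1, the range check -4 <= -1 <= 4 passes, the negative index normalizes to -1 + 4 = 3, so head = [2, 3, 4] and tail = [5]. A split point of 5 falls outside [-4, 4], so fold returns failure and the op is left as-is (see the out-of-range test below).
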
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index fad31d840523..ee69f90553d9 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -canonicalize <%s | FileCheck %s --dump-input=fail
+// RUN: mlir-opt -split-input-file -canonicalize <%s | FileCheck %s --dump-input=fail
 
 // -----
 // CHECK-LABEL: func @f
@@ -7,3 +7,82 @@ func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
   %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
   return %0 : !shape.shape
 }
+
+// -----
+// Basic case.
+// CHECK-LABEL: func @f
+func @f() -> (!shape.shape, !shape.shape) {
+  // CHECK: shape.const_shape [2, 3]
+  // CHECK: shape.const_shape [4, 5]
+  %c2 = constant 2 : i32
+  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+  return %head, %tail : !shape.shape, !shape.shape
+
+}
+
+// -----
+// Negative split point.
+// CHECK-LABEL: func @f
+func @f() -> (!shape.shape, !shape.shape) {
+  // CHECK: shape.const_shape [2, 3, 4]
+  // CHECK: shape.const_shape [5]
+  %c-1 = constant -1 : i32
+  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+  return %head, %tail : !shape.shape, !shape.shape
+}
+
+// -----
+// Out of range split point. No folding.
+// CHECK-LABEL: func @f
+func @f() -> (!shape.shape, !shape.shape) {
+  // CHECK: shape.split_at
+  %c5 = constant 5 : i32
+  %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
+  %head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
+  return %head, %tail : !shape.shape, !shape.shape
+}
+
+// -----
+// Basic case.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2]
+  %0 = shape.const_shape [1, 2]
+  %1 = shape.const_shape [7, 1]
+  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+// Incompatible shapes. No folding.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK: shape.broadcast
+  %0 = shape.const_shape [2]
+  %1 = shape.const_shape [7]
+  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+// Basic case.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK: shape.const_shape [0, 1, 2, 3]
+  %lhs = shape.const_shape [0, 1]
+  %rhs = shape.const_shape [2, 3]
+  %0 = "shape.concat"(%lhs, %rhs) : (!shape.shape, !shape.shape) -> !shape.shape
+  return %0 : !shape.shape
+}
+
+// -----
+// Basic case.
+// CHECK-LABEL: func @f
+func @f() -> tensor<2xindex> {
+  // CHECK: constant dense<[0, 1]> : tensor<2xindex>
+  %cs = shape.const_shape [0, 1]
+  %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex>
+  return %0 : tensor<2xindex>
+}
