[Mlir-commits] [mlir] 783a351 - [MLIR][Shape] Allow `shape.mul` to operate in indices

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 06:25:58 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T13:25:40Z
New Revision: 783a351785c14b7c2eb9f65bd40d37be11cbf38b

URL: https://github.com/llvm/llvm-project/commit/783a351785c14b7c2eb9f65bd40d37be11cbf38b
DIFF: https://github.com/llvm/llvm-project/commit/783a351785c14b7c2eb9f65bd40d37be11cbf38b.diff

LOG: [MLIR][Shape] Allow `shape.mul` to operate in indices

Differential Revision: https://reviews.llvm.org/D84437

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 32d6ebafff32..425cf917283b 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -307,18 +307,25 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
   let results = (outs Shape_ShapeOrSizeType:$result);
 }
 
-def Shape_MulOp : Shape_Op<"mul", [Commutative, SameOperandsAndResultType]> {
-  let summary = "Multiplication of sizes";
+def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
+  let summary = "Multiplication of sizes and indices";
   let description = [{
-    Multiplies two valid sizes as follows:
-    - lhs * rhs = unknown if either lhs or rhs unknown;
-    - lhs * rhs = (int)lhs * (int)rhs if both known;
+    Multiplies two sizes or indices. If either operand is an error it will be
+    propagated to the result. The operands can be of type `size` or `index`. If
+    at least one of the operands can hold an error, i.e. if it is of type `size`,
+    then also the result must be of type `size`. If error propagation is not
+    possible because both operands are of type `index` then the result must also
+    be of type `index`.
   }];
 
-  let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
-  let results = (outs Shape_SizeType:$result);
+  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
+  let results = (outs Shape_SizeOrIndexType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = [{
+    $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($result) attr-dict
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3bdc5cc39a7b..2f641300c491 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -28,6 +28,13 @@ static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
+static bool isErrorPropagationPossible(ArrayRef<Type> operandTypes) {
+  for (Type ty : operandTypes)
+    if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
+      return true;
+  return false;
+}
+
 ShapeDialect::ShapeDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
@@ -539,9 +546,7 @@ static LogicalResult verify(GetExtentOp op) {
   Type shapeTy = op.shape().getType();
   Type dimTy = op.dim().getType();
   Type extentTy = op.extent().getType();
-  bool errorPropagationPossible =
-      shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
-  if (errorPropagationPossible) {
+  if (isErrorPropagationPossible({shapeTy, dimTy})) {
     if (!extentTy.isa<SizeType>())
       op.emitError()
           << "if at least one of the operands can hold error values then the "
@@ -593,9 +598,8 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(shape::RankOp op) {
-  Type argTy = op.shape().getType();
-  Type resultTy = op.rank().getType();
-  if (argTy.isa<ShapeType>() && !resultTy.isa<SizeType>())
+  if (op.shape().getType().isa<ShapeType>() &&
+      !op.rank().getType().isa<SizeType>())
     return op.emitOpError()
            << "if operand is of type `shape` then the result must be of type "
               "`size` to propagate potential errors";
@@ -672,6 +676,25 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
   return builder.getIndexAttr(product.getLimitedValue());
 }
 
+//===----------------------------------------------------------------------===//
+// MulOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(MulOp op) {
+  Type resultTy = op.result().getType();
+  if (isErrorPropagationPossible({op.lhs().getType(), op.rhs().getType()})) {
+    if (!resultTy.isa<SizeType>())
+      return op.emitOpError()
+             << "if at least one of the operands can hold error values then "
+                "the result must be of type `size` to propagate them";
+  } else {
+    if (resultTy.isa<SizeType>())
+      return op.emitError() << "if none of the operands can hold error values "
+                               "then the result must be of type `index`";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
@@ -685,15 +708,13 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
 }
 
 static LogicalResult verify(ShapeOfOp op) {
-  Type argTy = op.arg().getType();
   Type resultTy = op.result().getType();
-  if (argTy.isa<ValueShapeType>()) {
+  if (isErrorPropagationPossible(op.arg().getType())) {
     if (!resultTy.isa<ShapeType>())
       return op.emitOpError()
              << "if operand is of type `value_shape` then the result must be "
                 "of type `shape` to propagate potential error shapes";
   } else {
-    assert(argTy.isa<ShapedType>());
     if (resultTy != getExtentTensorType(op.getContext()))
       return op.emitOpError() << "if operand is a shaped type then the result "
                                  "must be an extent tensor";

diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index 467f3d33ce23..bb2b03b8ec08 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -38,8 +38,8 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
   // Generate reduce operator.
   Block *body = reduce.getBody();
   OpBuilder b = OpBuilder::atBlockEnd(body);
-  Value product =
-      b.create<MulOp>(loc, body->getArgument(1), body->getArgument(2));
+  Value product = b.create<MulOp>(loc, b.getType<SizeType>(),
+                                  body->getArgument(1), body->getArgument(2));
   b.create<YieldOp>(loc, product);
 
   rewriter.replaceOp(op, reduce.result());

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 908acabe5345..8236c6f27975 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -24,10 +24,19 @@ func @shape_id(%shape : !shape.shape) -> !shape.shape {
 // CHECK-LABEL: @binary_ops
 // CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
 func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
+  // CHECK: addi %[[LHS]], %[[RHS]] : index
   %sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size
-  // CHECK-NEXT: addi %[[LHS]], %[[RHS]] : index
-  %product = shape.mul %lhs, %rhs
-  // CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
+  return
+}
+
+// -----
+
+// Lower binary ops.
+// CHECK-LABEL: @binary_ops
+// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
+func @binary_ops(%lhs : index, %rhs : index) {
+  // CHECK: muli %[[LHS]], %[[RHS]] : index
+  %product = shape.mul %lhs, %rhs : index, index -> index
   return
 }
 

diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d7e9e40ed3f2..b4900e491fb8 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -6,6 +6,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
     ^bb0(%index: index, %dim: !shape.size):
       shape.yield %dim : !shape.size
   }
+  return
 }
 
 // -----
@@ -18,6 +19,7 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
           : (!shape.size, !shape.size) -> !shape.size
       shape.yield %new_acc : !shape.size
   }
+  return
 }
 
 // -----
@@ -28,6 +30,7 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
     ^bb0(%index: index, %dim: f32, %lci: !shape.size):
       shape.yield
   }
+  return
 }
 
 // -----
@@ -38,6 +41,7 @@ func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {
     ^bb0(%index: index, %dim: f32, %lci: index):
       shape.yield
   }
+  return
 }
 
 // -----
@@ -48,6 +52,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       shape.yield
   }
+  return
 }
 
 // -----
@@ -58,6 +63,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       shape.yield %dim, %dim : !shape.size, !shape.size
   }
+  return
 }
 
 // -----
@@ -69,6 +75,7 @@ func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
       %c0 = constant 1 : index
       shape.yield %c0 : index
   }
+  return
 }
 
 // -----
@@ -85,6 +92,7 @@ func @shape_of(%value_arg : !shape.value_shape,
                %shaped_arg : tensor<?x3x4xf32>) {
  // expected-error @+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
   %0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
+  return
 }
 
 // -----
@@ -93,6 +101,7 @@ func @shape_of(%value_arg : !shape.value_shape,
                %shaped_arg : tensor<?x3x4xf32>) {
  // expected-error @+1 {{if operand is a shaped type then the result must be an extent tensor}}
   %1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
+  return
 }
 
 // -----
@@ -100,6 +109,7 @@ func @shape_of(%value_arg : !shape.value_shape,
 func @rank(%arg : !shape.shape) {
  // expected-error @+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
   %0 = shape.rank %arg : !shape.shape -> index
+  return
 }
 
 // -----
@@ -120,3 +130,19 @@ func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
   return %result : index
 }
 
+// -----
+
+func @mul_error_free(%arg : index) -> !shape.size {
+  // expected-error @+1 {{if none of the operands can hold error values then the result must be of type `index`}}
+  %result = shape.mul %arg, %arg : index, index -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
+  // expected-error @+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+  %result = shape.mul %lhs, %rhs : !shape.size, index -> index
+  return %result : index
+}
+

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index b6b839251a88..3a0cb7781ec7 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -9,6 +9,7 @@ func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
   %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
     ^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
       %acc_next = shape.mul %acc, %extent
+          : !shape.size, !shape.size -> !shape.size
       shape.yield %acc_next : !shape.size
   }
   return %num_elements : !shape.size
@@ -19,7 +20,7 @@ func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
   %init = constant 1 : index
   %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
     ^bb0(%index : index, %extent : index, %acc : index):
-      %acc_next = muli %acc, %extent : index
+      %acc_next = shape.mul %acc, %extent : index, index -> index
       shape.yield %acc_next : index
   }
   return %num_elements : index
@@ -110,9 +111,13 @@ func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
   return
 }
 
-func @test_mul(%lhs: !shape.size, %rhs: !shape.size) -> !shape.size {
-  %product = shape.mul %lhs, %rhs
-  return %product: !shape.size
+func @mul(%size_arg : !shape.size, %index_arg : index) {
+  %size_prod = shape.mul %size_arg, %size_arg
+      : !shape.size, !shape.size -> !shape.size
+  %index_prod = shape.mul %index_arg, %index_arg : index, index -> index
+  %mixed_prod = shape.mul %size_arg, %index_arg
+      : !shape.size, index -> !shape.size
+  return
 }
 
 func @const_size() {


        


More information about the Mlir-commits mailing list