[Mlir-commits] [mlir] d690cbf - Add DivOp to the Shape dialect
Jacques Pienaar
llvmlistbot at llvm.org
Thu Feb 18 16:59:09 PST 2021
Author: Jing Pu
Date: 2021-02-18T16:58:47-08:00
New Revision: d690cbf821f1d4e13bdf97779ffb332624890fcb
URL: https://github.com/llvm/llvm-project/commit/d690cbf821f1d4e13bdf97779ffb332624890fcb
DIFF: https://github.com/llvm/llvm-project/commit/d690cbf821f1d4e13bdf97779ffb332624890fcb.diff
LOG: Add DivOp to the Shape dialect
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D96907
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 30e8ca10150d..ecd7f3ae0c95 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -138,6 +138,36 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
let hasFolder = 1;
}
+def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
+  let summary = "Division of sizes and indices";
+  let description = [{
+    Divides two sizes or indices. If either operand is an error it will be
+    propagated to the result. The operands can be of type `size` or `index`.
+    If at least one of the operands can hold an error, i.e. if it is of type
+    `size`, the result must be of type `size`. If error propagation is not
+    possible because both operands are of type `index` then the result may be
+    of type `size` or `index`. If both operands and result are of type `index`,
+    their runtime values could be negative. The result is rounded toward
+    negative infinity, i.e. floor(lhs / rhs), such that
+
+        div(lhs, rhs) * rhs + mod(lhs, rhs) = lhs
+
+    always holds, where `mod(lhs, rhs)` has the same sign as `rhs`. If any of
+    the values is of type `size`, the behavior for negative values is
+    undefined.
+  }];
+
+  let arguments = (ins Shape_SizeOrIndexType:$lhs,
+                       Shape_SizeOrIndexType:$rhs);
+  let results = (outs Shape_SizeOrIndexType:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
+
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+  let hasFolder = 1;
+}
+
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
let summary = "Returns whether the input shapes or extent tensors are equal";
let description = [{
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index b1199fb17b20..6146239cd4f8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -600,6 +600,30 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
return operands[0];
}
+//===----------------------------------------------------------------------===//
+// DivOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `shape.div` when both operands are constant integers, producing an
+/// index attribute that holds floor(lhs / rhs). Returns a null result (i.e.
+/// does not fold) if either operand is not a constant integer or if the
+/// divisor is zero.
+OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!lhs)
+    return nullptr;
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!rhs)
+    return nullptr;
+
+  APInt dividend = lhs.getValue();
+  APInt divisor = rhs.getValue();
+  // APInt::sdivrem requires a non-zero divisor; leave division by zero
+  // unfolded rather than crashing the folder.
+  if (divisor.isNullValue())
+    return nullptr;
+
+  // Division in APInt rounds toward zero, so the remainder takes the sign of
+  // the dividend. floor(lhs / rhs) instead rounds toward negative infinity:
+  // whenever the remainder is non-zero and its sign differs from the
+  // divisor's (i.e. the exact quotient is negative), step the truncated
+  // quotient down by one. Testing the remainder rather than the quotient also
+  // covers cases like -1 / 4, whose truncated quotient is 0 but whose floor
+  // is -1.
+  APInt quotient, remainder;
+  APInt::sdivrem(dividend, divisor, quotient, remainder);
+  if (!remainder.isNullValue() &&
+      remainder.isNegative() != divisor.isNegative())
+    quotient -= 1;
+
+  Type indexTy = IndexType::get(getContext());
+  return IntegerAttr::get(indexTy, quotient);
+}
+
//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 3a04c16cb9f0..d5bf3f711985 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -950,6 +950,71 @@ func @fold_mul_mixed() -> !shape.size {
// -----
+// Fold `div` for constant sizes.
+// CHECK-LABEL: @fold_div_size
+func @fold_div_size() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK: return %[[RESULT]] : !shape.size
+  %lhs = shape.const_size 10
+  %rhs = shape.const_size 3
+  %result = shape.div %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+// Fold `div` for constant indices.
+// CHECK-LABEL: @fold_div_index
+func @fold_div_index() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 2 : index
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant 10 : index
+  %rhs = constant 4 : index
+  %result = shape.div %lhs, %rhs : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for constant indices with a negative lhs (rounds toward -inf).
+// CHECK-LABEL: @fold_div_index_neg_lhs
+func @fold_div_index_neg_lhs() -> index {
+  // CHECK: %[[RESULT:.*]] = constant -3 : index
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant -10 : index
+  %rhs = constant 4 : index
+  %result = shape.div %lhs, %rhs : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for constant indices with a negative rhs (rounds toward -inf).
+// CHECK-LABEL: @fold_div_index_neg_rhs
+func @fold_div_index_neg_rhs() -> index {
+  // CHECK: %[[RESULT:.*]] = constant -3 : index
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant 10 : index
+  %rhs = constant -4 : index
+  %result = shape.div %lhs, %rhs : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for mixed constants.
+// CHECK-LABEL: @fold_div_mixed
+func @fold_div_mixed() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 4
+  // CHECK: return %[[RESULT]] : !shape.size
+  %lhs = shape.const_size 12
+  %rhs = constant 3 : index
+  %result = shape.div %lhs, %rhs : !shape.size, index -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
// Fold index_cast when already on index.
// CHECK-LABEL: @fold_index_cast_on_index
func @fold_index_cast_on_index(%arg: index) -> index {
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 57195baad7c0..1a5735ccbb5e 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -129,6 +129,15 @@ func @mul(%size_arg : !shape.size, %index_arg : index) {
return
}
+// Round-trip every operand/result type combination accepted by `shape.div`.
+func @div(%size_arg : !shape.size, %index_arg : index) {
+  %size_div = shape.div %size_arg, %size_arg
+      : !shape.size, !shape.size -> !shape.size
+  %index_div = shape.div %index_arg, %index_arg : index, index -> index
+  // Both operands are `index`, so the result may also be `size`.
+  %index_to_size_div = shape.div %index_arg, %index_arg
+      : index, index -> !shape.size
+  %mixed_div = shape.div %size_arg, %index_arg
+      : !shape.size, index -> !shape.size
+  %mixed_div_swapped = shape.div %index_arg, %size_arg
+      : index, !shape.size -> !shape.size
+  return
+}
+
func @add(%size_arg : !shape.size, %index_arg : index) {
%size_sum = shape.add %size_arg, %size_arg
: !shape.size, !shape.size -> !shape.size
More information about the Mlir-commits
mailing list