[Mlir-commits] [mlir] f959585 - [MLIR][Shape] Fold `shape.shape_eq`
Frederik Gossen
llvmlistbot at llvm.org
Mon Jul 20 05:26:06 PDT 2020
Author: Frederik Gossen
Date: 2020-07-20T12:25:53Z
New Revision: f9595857b9f868fc7724ea767a8fd984d02848ff
URL: https://github.com/llvm/llvm-project/commit/f9595857b9f868fc7724ea767a8fd984d02848ff
DIFF: https://github.com/llvm/llvm-project/commit/f9595857b9f868fc7724ea767a8fd984d02848ff.diff
LOG: [MLIR][Shape] Fold `shape.shape_eq`
Fold `shape.shape_eq`.
Differential Revision: https://reviews.llvm.org/D82533
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 090b4c6f4abb..85082419b30e 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -154,6 +154,7 @@ def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
let results = (outs I1:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+ let hasFolder = 1;
}
def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index b983968b124d..92392b069a04 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -475,6 +475,20 @@ void ConstSizeOp::getAsmResultNames(
// Always folds to the value of `passingAttr()` (presumably the op's `passing`
// attribute); the constant operand attributes are ignored (parameter unnamed).
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
+//===----------------------------------------------------------------------===//
+// ShapeEqOp
+//===----------------------------------------------------------------------===//
+
+// Folds `shape_eq` to a constant boolean once both operands are constant.
+// If either operand has no constant value, or its constant is not a
+// DenseIntElementsAttr, the op is left unfolded. Otherwise the two attributes
+// are compared with `==`; the accompanying tests expect equal constant shapes
+// to fold to `true`, and shapes with different extents or lengths to `false`.
+OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
+  auto constLhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  auto constRhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!constLhs || !constRhs)
+    return {};
+  bool sameShape = constLhs == constRhs;
+  return BoolAttr::get(sameShape, getContext());
+}
+
//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 4e320f303b18..a58b230f7981 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -596,3 +596,56 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
+
+// -----
+
+// Two identical constant shapes: `shape_eq` is expected to fold to `true`.
+// CHECK-LABEL: @shape_eq_fold_1
+func @shape_eq_fold_1() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %lhs = shape.const_shape [1, 2, 3]
+  %rhs = shape.const_shape [1, 2, 3]
+  %eq = shape.shape_eq %lhs, %rhs : !shape.shape, !shape.shape
+  return %eq : i1
+}
+
+// -----
+
+// Fold `shape_eq` for different but constant shapes of same length.
+// CHECK-LABEL: @shape_eq_fold_0
+func @shape_eq_fold_0() -> i1 {
+  // Same length, every extent differs: expected to fold to `false`.
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %lhs = shape.const_shape [1, 2, 3]
+  %rhs = shape.const_shape [4, 5, 6]
+  %eq = shape.shape_eq %lhs, %rhs : !shape.shape, !shape.shape
+  return %eq : i1
+}
+
+// -----
+
+// Fold `shape_eq` for different but constant shapes of different length.
+// CHECK-LABEL: @shape_eq_different_length_fold_0
+func @shape_eq_different_length_fold_0() -> i1 {
+  // Shapes of different lengths are never equal: expected to fold to `false`.
+  // The name is distinct from the same-length test so FileCheck failures point
+  // at the right case.
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3, 4, 5, 6]
+  %b = shape.const_shape [1, 2, 3]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// `shape_eq` must stay unfolded when one operand is not a constant shape.
+// CHECK-LABEL: @shape_eq_do_not_fold
+// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
+func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
+  // CHECK: %[[CST:.*]] = shape.const_shape [4, 5, 6]
+  // CHECK: %[[RESULT:.*]] = shape.shape_eq %[[A]], %[[CST]] : !shape.shape, !shape.shape
+  // CHECK: return %[[RESULT]] : i1
+  %cst_shape = shape.const_shape [4, 5, 6]
+  %eq = shape.shape_eq %a, %cst_shape : !shape.shape, !shape.shape
+  return %eq : i1
+}
More information about the Mlir-commits mailing list