[Mlir-commits] [mlir] f959585 - [MLIR][Shape] Fold `shape.shape_eq`

Frederik Gossen llvmlistbot at llvm.org
Mon Jul 20 05:26:06 PDT 2020


Author: Frederik Gossen
Date: 2020-07-20T12:25:53Z
New Revision: f9595857b9f868fc7724ea767a8fd984d02848ff

URL: https://github.com/llvm/llvm-project/commit/f9595857b9f868fc7724ea767a8fd984d02848ff
DIFF: https://github.com/llvm/llvm-project/commit/f9595857b9f868fc7724ea767a8fd984d02848ff.diff

LOG: [MLIR][Shape] Fold `shape.shape_eq`

Fold `shape.shape_eq`.

Differential Revision: https://reviews.llvm.org/D82533

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 090b4c6f4abb..85082419b30e 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -154,6 +154,7 @@ def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
   let results = (outs I1:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index b983968b124d..92392b069a04 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -475,6 +475,20 @@ void ConstSizeOp::getAsmResultNames(
 
 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
 
+//===----------------------------------------------------------------------===//
+// ShapeEqOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (lhs == nullptr)
+    return {};
+  auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (rhs == nullptr)
+    return {};
+  return BoolAttr::get(lhs == rhs, getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 4e320f303b18..a58b230f7981 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -596,3 +596,56 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
+
+// -----
+
+// Fold `shape_eq` for equal and constant shapes.
+// CHECK-LABEL: @shape_eq_fold_1
+func @shape_eq_fold_1() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [1, 2, 3]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Fold `shape_eq` for different but constant shapes of same length.
+// CHECK-LABEL: @shape_eq_fold_0
+func @shape_eq_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [4, 5, 6]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Fold `shape_eq` for different but constant shapes of different length.
+// CHECK-LABEL: @shape_eq_fold_0
+func @shape_eq_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3, 4, 5, 6]
+  %b = shape.const_shape [1, 2, 3]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Do not fold `shape_eq` for non-constant shapes.
+// CHECK-LABEL: @shape_eq_do_not_fold
+// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
+func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
+  // CHECK: %[[B:.*]] = shape.const_shape [4, 5, 6]
+  // CHECK: %[[RESULT:.*]] = shape.shape_eq %[[A]], %[[B]] : !shape.shape, !shape.shape
+  // CHECK: return %[[RESULT]] : i1
+  %b = shape.const_shape [4, 5, 6]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}


        


More information about the Mlir-commits mailing list