[Mlir-commits] [mlir] 8b109bc - [mlir, shape] Add max/min folder for simple case

Jacques Pienaar llvmlistbot at llvm.org
Tue Apr 6 20:23:11 PDT 2021


Author: Jacques Pienaar
Date: 2021-04-06T20:22:42-07:00
New Revision: 8b109bc2eae0d33a140982c02c77501932bfa394

URL: https://github.com/llvm/llvm-project/commit/8b109bc2eae0d33a140982c02c77501932bfa394
DIFF: https://github.com/llvm/llvm-project/commit/8b109bc2eae0d33a140982c02c77501932bfa394.diff

LOG: [mlir,shape] Add max/min folder for simple case

When both operands of shape.max or shape.min are the same value, fold the op to that operand.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 0b8c26dc91565..41e6f8a2a5627 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -416,6 +416,8 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
   let assemblyFormat = [{
     $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
   }];
+
+  let hasFolder = 1;
 }
 
 def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
@@ -433,6 +435,8 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
   let assemblyFormat = [{
     $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
   }];
+
+  let hasFolder = 1;
 }
 
 def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index bb7ed5cf05cec..388a3a5763b12 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -937,6 +937,28 @@ void NumElementsOp::build(OpBuilder &builder, OperationState &result,
   return build(builder, result, type, shape);
 }
 
+//===----------------------------------------------------------------------===//
+// MaxOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
+  // max(x, x) == x, but only forward x when its type matches the result's:
+  // a folder must never replace a result with a differently typed value.
+  bool canForward = lhs() == rhs() && lhs().getType() == getType();
+  return canForward ? OpFoldResult(lhs()) : OpFoldResult();
+}
+
+//===----------------------------------------------------------------------===//
+// MinOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
+  // min(x, x) == x, but only forward x when its type matches the result's:
+  // a folder must never replace a result with a differently typed value.
+  bool canForward = lhs() == rhs() && lhs().getType() == getType();
+  return canForward ? OpFoldResult(lhs()) : OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b0c12ea0b1499..86ac4c9af9632 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1188,3 +1188,23 @@ func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
   return %1 : tensor<3xindex>
 }
+
+// -----
+// shape.max of a value with itself folds away to that value.
+// CHECK-LABEL: @max_same_arg
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
+func @max_same_arg(%a: !shape.shape) -> !shape.shape {
+  %1 = shape.max %a, %a : !shape.shape, !shape.shape -> !shape.shape
+  // CHECK-NEXT: return %[[SHAPE]]
+  return %1 : !shape.shape
+}
+
+// -----
+// shape.min of a value with itself folds away to that value.
+// CHECK-LABEL: @min_same_arg
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
+func @min_same_arg(%a: !shape.shape) -> !shape.shape {
+  %1 = shape.min %a, %a : !shape.shape, !shape.shape -> !shape.shape
+  // CHECK-NEXT: return %[[SHAPE]]
+  return %1 : !shape.shape
+}


        


More information about the Mlir-commits mailing list