[Mlir-commits] [mlir] 1288ada - [MLIR][Shape] Remove duplicate operands of `shape.assuming_all` op

Frederik Gossen llvmlistbot at llvm.org
Mon May 31 05:38:17 PDT 2021


Author: Frederik Gossen
Date: 2021-05-31T14:37:55+02:00
New Revision: 1288adaa7350ca1b6aba1f535710dbe89f16de65

URL: https://github.com/llvm/llvm-project/commit/1288adaa7350ca1b6aba1f535710dbe89f16de65
DIFF: https://github.com/llvm/llvm-project/commit/1288adaa7350ca1b6aba1f535710dbe89f16de65.diff

LOG: [MLIR][Shape] Remove duplicate operands of `shape.assuming_all` op

Differential Revision: https://reviews.llvm.org/D103403

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2c38710426696..158a6d5763fbb 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -429,11 +429,36 @@ struct AssumingAllToCstrEqCanonicalization
     return success();
   }
 };
+
+template <typename OpTy>
+struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Find unique operands.
+    SmallVector<Value, 2> unique;
+    for (Value v : op.getOperands()) {
+      if (!llvm::is_contained(unique, v))
+        unique.push_back(v);
+    }
+
+    // Reduce op to equivalent with unique operands.
+    if (unique.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
+                                        op->getAttrs());
+      return success();
+    }
+
+    return failure();
+  }
+};
 } // namespace
 
 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
+  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
+               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
 }
 
 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
@@ -508,30 +533,6 @@ static LogicalResult verify(BroadcastOp op) {
 }
 
 namespace {
-template <typename OpTy>
-struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Find unique operands.
-    SmallVector<Value, 2> unique;
-    for (Value v : op.getOperands()) {
-      if (!llvm::is_contained(unique, v))
-        unique.push_back(v);
-    }
-
-    // Reduce op to equivalent with unique operands.
-    if (unique.size() < op.getNumOperands()) {
-      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
-                                        op->getAttrs());
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 template <typename OpTy>
 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 6e02438391320..a929e4d0d3e1a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -477,6 +477,19 @@ func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
   return %2 : !shape.witness
 }
 
+// -----
+// `assuming_all` with duplicate operands.
+// CHECK-LABEL: func @assuming_all_duplicate_operands
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
+func @assuming_all_duplicate_operands(%arg0 : tensor<?xindex>,
+    %arg1 : tensor<?xindex>) -> !shape.witness {
+  // CHECK: %[[RES:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
+  // CHECK: return %[[RES]]
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  %1 = shape.assuming_all %0, %0, %0
+  return %1 : !shape.witness
+}
+
 // -----
 // `assuming_all` with all `cstr_eq` but disjoint operands cannot be collapsed.
 // CHECK-LABEL: func @assuming_all_to_cstr_eq


        


More information about the Mlir-commits mailing list