[llvm-branch-commits] [mlir] 5844bc5 - [mlir][Shape] Canonicalize assuming_all with one input and tensor_cast of const_shape

Benjamin Kramer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Dec 8 08:12:19 PST 2020


Author: Benjamin Kramer
Date: 2020-12-08T17:07:24+01:00
New Revision: 5844bc540cafb4330e7625b83371f1dab90528c3

URL: https://github.com/llvm/llvm-project/commit/5844bc540cafb4330e7625b83371f1dab90528c3
DIFF: https://github.com/llvm/llvm-project/commit/5844bc540cafb4330e7625b83371f1dab90528c3.diff

LOG: [mlir][Shape] Canonicalize assuming_all with one input and tensor_cast of const_shape

This allows simplifying some more complicated shape expressions.

Differential Revision: https://reviews.llvm.org/D92843

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 52768e49001d..552de7e78f91 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -105,6 +105,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_ConstSizeOp : Shape_Op<"const_size", [
@@ -630,6 +631,7 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]>
   let assemblyFormat = "$inputs attr-dict";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 
   let verifier = [{ return ::verify(*this); }];
 }

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index fe57f7d7a52e..acb35b916f7e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -271,6 +271,12 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
+
+void AssumingAllOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *ctx) {
+  results.insert<AssumingAllOneOp>(ctx); // Forward a lone input witness.
+}
+
 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
   // Iterate in reverse to first handle all constant operands. They are
   // guaranteed to be the tail of the inputs because this is commutative.
@@ -394,6 +400,11 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
 
 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
+void ConstShapeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *ctx) {
+  results.insert<TensorCastConstShape>(ctx); // Rooted at tensor_cast.
+}
+
 //===----------------------------------------------------------------------===//
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index c57ad8c8d17c..43c670a8582e 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,4 +1,5 @@
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 
 def AllInputShapesEq : Constraint<CPred< [{
   llvm::all_of($0, [&](mlir::Value val) {
@@ -6,8 +7,16 @@ def AllInputShapesEq : Constraint<CPred< [{
   })
 }]>>;
 
+def HasSingleElement : Constraint<CPred< [{
+  llvm::hasSingleElement($0)
+}]>>;
+
 // Canonicalization patterns.
 
+def AssumingAllOneOp : Pat<
+  (Shape_AssumingAllOp $inputs), (replaceWithValue $inputs),
+  [(HasSingleElement $inputs)]>;
+
 def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
   (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
 
@@ -23,3 +32,15 @@ def SizeToIndexToSizeCanonicalization : Pat<
   (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
   (replaceWithValue $arg)>;
 
+def HasStaticShape : Constraint<CPred< [{
+  $0.getType().cast<ShapedType>().hasStaticShape()
+}]>>;
+
+// tensor_cast(const_shape) -> const_shape carrying the cast's result type.
+// Forwarding the original const_shape value is wrong: it may have a different
+// type than the cast (e.g. tensor<?xindex> vs. tensor<1xindex>), breaking
+// users that need the exact type. Only fire for statically shaped results,
+// which also excludes unranked types a const_shape presumably cannot carry.
+def TensorCastConstShape : Pat <
+  (TensorCastOp:$res (Shape_ConstShapeOp $shape)), (Shape_ConstShapeOp $shape),
+  [(HasStaticShape $res)]>;

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 56a6ef74f54e..9cb01da75901 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -427,20 +427,23 @@ func @f() {
 
 // -----
 
-// assuming_all should not be removed if not all witnesses are statically passing.
+// assuming_all should not be removed if more than one witness is not
+// statically passing.
 //
 // Additionally check that the attribute is moved to the end as this op is
 // commutative.
 // CHECK-LABEL: func @f
 func @f() {
-  // CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source"
-  // CHECK-NEXT: shape.assuming_all %[[UNKNOWN]]
+  // CHECK-NEXT: %[[UNKNOWN1:.*]] = "test.source"
+  // CHECK-NEXT: %[[UNKNOWN2:.*]] = "test.source"
+  // CHECK-NEXT: shape.assuming_all %[[UNKNOWN1]], %[[UNKNOWN2]]
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
   %0 = shape.const_witness true
   %1 = "test.source"() : () -> !shape.witness
-  %2 = shape.assuming_all %0, %1
-  "consume.witness"(%2) : (!shape.witness) -> ()
+  %2 = "test.source"() : () -> !shape.witness
+  %3 = shape.assuming_all %0, %1, %2
+  "consume.witness"(%3) : (!shape.witness) -> ()
   return
 }
 
@@ -854,3 +857,29 @@ func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex>
   %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
   return %casted : tensor<?xindex>
 }
+
+// -----
+
+// Fold assuming_all with a single input into that input.
+// CHECK-LABEL: @fold_assuming_all_single_element
+func @fold_assuming_all_single_element() {
+  // CHECK-NEXT: %[[WITNESS:.*]] = "test.source"
+  // CHECK-NEXT: "consume.witness"(%[[WITNESS]])
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1 = shape.assuming_all %0
+  "consume.witness"(%1) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Fold tensor_cast of a const_shape to const_shape.
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape
+func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) {
+  // CHECK-NOT: tensor_cast
+  %cs = shape.const_shape [2] : tensor<?xindex>
+  %cast = tensor_cast %cs : tensor<?xindex> to tensor<1xindex>
+  %w = shape.cstr_broadcastable %cast, %cs : tensor<1xindex>, tensor<?xindex>
+  "consume.witness"(%w) : (!shape.witness) -> ()
+  return
+}


        


More information about the llvm-branch-commits mailing list