[Mlir-commits] [mlir] 8ef4724 - [mlir][shape] Fold shape.broadcast with one scalar operand

Stephan Herhut llvmlistbot at llvm.org
Wed Jul 15 10:10:37 PDT 2020


Author: Stephan Herhut
Date: 2020-07-15T18:49:12+02:00
New Revision: 8ef47244b95f7b148e072a19563f6096ed4fe43c

URL: https://github.com/llvm/llvm-project/commit/8ef47244b95f7b148e072a19563f6096ed4fe43c
DIFF: https://github.com/llvm/llvm-project/commit/8ef47244b95f7b148e072a19563f6096ed4fe43c.diff

LOG: [mlir][shape] Fold shape.broadcast with one scalar operand

This folds a shape.broadcast whose operand is a scalar (a constant empty
shape) to its other operand.

It also adds an assemblyFormat for shape.broadcast and shape.concat and
updates the tests to use the new syntax.

Differential Revision: https://reviews.llvm.org/D83854

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 38bac19d0fa8..1f141a2e705a 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -90,6 +90,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
                    OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeType:$result);
 
+  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+
   let hasFolder = 1;
 }
 
@@ -488,6 +490,7 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {
   let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
   let results = (outs Shape_ShapeType:$result);
 
+  let assemblyFormat = "$lhs `,` $rhs attr-dict";
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0a0608bbcda4..a6f54053a326 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -237,12 +237,22 @@ static LogicalResult verify(AssumingAllOp op) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1])
+  if (!operands[1])
     return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+
   auto rhsShape = llvm::to_vector<6>(
       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  if (rhsShape.empty())
+    return lhs();
+
+  if (!operands[0])
+    return nullptr;
+
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  if (lhsShape.empty())
+    return rhs();
+
   SmallVector<int64_t, 6> resultShape;
   // If the shapes are not compatible, we can't fold it.
   // TODO: Fold to an "error".

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 1665ef73f3e3..4e320f303b18 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -54,7 +54,42 @@ func @f() -> !shape.shape {
   // CHECK: shape.const_shape [7, 2]
   %0 = shape.const_shape [1, 2]
   %1 = shape.const_shape [7, 1]
-  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = shape.broadcast %0, %1
+  return %2 : !shape.shape
+}
+
+// -----
+
+// Rhs is a scalar.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %arg0
+  %0 = shape.const_shape []
+  %1 = shape.broadcast %arg0, %0
+  return %1 : !shape.shape
+}
+
+// -----
+
+// Lhs is a scalar.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %arg0
+  %0 = shape.const_shape []
+  %1 = shape.broadcast %0, %arg0
+  return %1 : !shape.shape
+}
+
+// -----
+
+// Lhs is a scalar and rhs is constant.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+  // CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3]
+  // CHECK: return %[[CST]]
+  %0 = shape.const_shape []
+  %1 = shape.const_shape [1, 2, 3]
+  %2 = shape.broadcast %0, %1
   return %2 : !shape.shape
 }
 
@@ -66,7 +101,7 @@ func @f() -> !shape.shape {
   // CHECK: shape.broadcast
   %0 = shape.const_shape [2]
   %1 = shape.const_shape [7]
-  %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = shape.broadcast %0, %1
   return %2 : !shape.shape
 }
 
@@ -78,7 +113,7 @@ func @f() -> !shape.shape {
   // CHECK: shape.const_shape [0, 1, 2, 3]
   %lhs = shape.const_shape [0, 1]
   %rhs = shape.const_shape [2, 3]
-  %0 = "shape.concat"(%lhs, %rhs) : (!shape.shape, !shape.shape) -> !shape.shape
+  %0 = shape.concat %lhs, %rhs
   return %0 : !shape.shape
 }
 

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 94323e856750..3a0bcf713073 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -29,10 +29,10 @@ func @test_shape_num_elements_fixed() {
   return
 }
 
-func @test_broadcastable_fixed() {
+func @test_broadcast_fixed() {
   %0 = shape.const_shape [10, 1, 57, 92]
   %1 = shape.const_shape [4, 57, 92]
-  %2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %2 = shape.broadcast %0, %1
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }


        


More information about the Mlir-commits mailing list