[Mlir-commits] [mlir] 8ef4724 - [mlir][shape] Fold shape.broadcast with one scalar operand
Stephan Herhut
llvmlistbot at llvm.org
Wed Jul 15 10:10:37 PDT 2020
Author: Stephan Herhut
Date: 2020-07-15T18:49:12+02:00
New Revision: 8ef47244b95f7b148e072a19563f6096ed4fe43c
URL: https://github.com/llvm/llvm-project/commit/8ef47244b95f7b148e072a19563f6096ed4fe43c
DIFF: https://github.com/llvm/llvm-project/commit/8ef47244b95f7b148e072a19563f6096ed4fe43c.diff
LOG: [mlir][shape] Fold shape.broadcast with one scalar operand
This folds shape.broadcast to the other operand when at least one of its
operands is a scalar (i.e., an empty shape).
Also add an assemblyFormat for shape.broadcast and shape.concat.
Differential Revision: https://reviews.llvm.org/D83854
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 38bac19d0fa8..1f141a2e705a 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -90,6 +90,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeType:$result);
+ let assemblyFormat = "$lhs `,` $rhs attr-dict";
+
let hasFolder = 1;
}
@@ -488,6 +490,7 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {
let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
let results = (outs Shape_ShapeType:$result);
+ let assemblyFormat = "$lhs `,` $rhs attr-dict";
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0a0608bbcda4..a6f54053a326 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -237,12 +237,22 @@ static LogicalResult verify(AssumingAllOp op) {
//===----------------------------------------------------------------------===//
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[0] || !operands[1])
+ if (!operands[1])
return nullptr;
- auto lhsShape = llvm::to_vector<6>(
- operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+
auto rhsShape = llvm::to_vector<6>(
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ if (rhsShape.empty())
+ return lhs();
+
+ if (!operands[0])
+ return nullptr;
+
+ auto lhsShape = llvm::to_vector<6>(
+ operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ if (lhsShape.empty())
+ return rhs();
+
SmallVector<int64_t, 6> resultShape;
// If the shapes are not compatible, we can't fold it.
// TODO: Fold to an "error".
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 1665ef73f3e3..4e320f303b18 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -54,7 +54,42 @@ func @f() -> !shape.shape {
// CHECK: shape.const_shape [7, 2]
%0 = shape.const_shape [1, 2]
%1 = shape.const_shape [7, 1]
- %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %2 = shape.broadcast %0, %1
+ return %2 : !shape.shape
+}
+
+// -----
+
+// Rhs is a scalar.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) -> !shape.shape {
+ // CHECK: return %arg0
+ %0 = shape.const_shape []
+ %1 = shape.broadcast %arg0, %0
+ return %1 : !shape.shape
+}
+
+// -----
+
+// Lhs is a scalar.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) -> !shape.shape {
+ // CHECK: return %arg0
+ %0 = shape.const_shape []
+ %1 = shape.broadcast %0, %arg0
+ return %1 : !shape.shape
+}
+
+// -----
+
+// Lhs is a scalar and rhs is constant.
+// CHECK-LABEL: func @f
+func @f() -> !shape.shape {
+ // CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3]
+ // CHECK: return %[[CST]]
+ %0 = shape.const_shape []
+ %1 = shape.const_shape [1, 2, 3]
+ %2 = shape.broadcast %0, %1
return %2 : !shape.shape
}
@@ -66,7 +101,7 @@ func @f() -> !shape.shape {
// CHECK: shape.broadcast
%0 = shape.const_shape [2]
%1 = shape.const_shape [7]
- %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %2 = shape.broadcast %0, %1
return %2 : !shape.shape
}
@@ -78,7 +113,7 @@ func @f() -> !shape.shape {
// CHECK: shape.const_shape [0, 1, 2, 3]
%lhs = shape.const_shape [0, 1]
%rhs = shape.const_shape [2, 3]
- %0 = "shape.concat"(%lhs, %rhs) : (!shape.shape, !shape.shape) -> !shape.shape
+ %0 = shape.concat %lhs, %rhs
return %0 : !shape.shape
}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 94323e856750..3a0bcf713073 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -29,10 +29,10 @@ func @test_shape_num_elements_fixed() {
return
}
-func @test_broadcastable_fixed() {
+func @test_broadcast_fixed() {
%0 = shape.const_shape [10, 1, 57, 92]
%1 = shape.const_shape [4, 57, 92]
- %2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %2 = shape.broadcast %0, %1
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
More information about the Mlir-commits
mailing list