[llvm-branch-commits] [mlir] 6e8ef3b - [mlir][Linalg] Make Fill operation work on tensors.
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 22 14:44:22 PST 2021
Author: MaheshRavishankar
Date: 2021-01-22T14:39:27-08:00
New Revision: 6e8ef3b76ab65960edd6ee99f387e75564d8d9db
URL: https://github.com/llvm/llvm-project/commit/6e8ef3b76ab65960edd6ee99f387e75564d8d9db
DIFF: https://github.com/llvm/llvm-project/commit/6e8ef3b76ab65960edd6ee99f387e75564d8d9db.diff
LOG: [mlir][Linalg] Make Fill operation work on tensors.
Depends on D95109
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile-tensors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 26db4c2f6735..436dab1ade2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -148,8 +148,9 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
}
def FillOp : LinalgStructured_Op<"fill", []> {
- let arguments = (ins AnyStridedMemRef:$output,
+ let arguments = (ins AnyShaped:$output,
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
+ let results = (outs Optional<AnyRankedTensor>:$result);
let extraClassDeclaration = libraryCallName # [{
ValueRange inputs() { return {}; }
ValueRange outputs() { return getOperands().take_front(); }
@@ -174,6 +175,14 @@ def FillOp : LinalgStructured_Op<"fill", []> {
}
}];
+ let assemblyFormat = [{
+ `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
+ }];
+
+ let builders = [
+ OpBuilderDAG<(ins "Value":$output, "Value":$value)>
+ ];
+
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b500eefa9d0c..a6f3576c4240 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -220,6 +220,16 @@ static LogicalResult foldMemRefCast(Operation *op) {
// LinalgOps.td), we define an overloaded `print` function and a
// parse`className` function.
+//===----------------------------------------------------------------------===//
+// FillOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a `linalg.fill` writing `value` into `output`. The result type passed
+/// on to the generated builder is `output`'s type when it is a ranked tensor and
+/// a null type otherwise (e.g. for memref outputs); presumably the generated
+/// builder for the Optional result then adds no result -- confirm in ODS output.
+void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
+                   Value value) {
+  RankedTensorType tensorResultType =
+      output.getType().dyn_cast<RankedTensorType>();
+  build(builder, result, tensorResultType, output, value);
+}
+
//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
@@ -1726,6 +1736,10 @@ static LogicalResult verify(FillOp op) {
auto fillType = op.value().getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
+ if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+ return op.emitOpError(
+ "expected fill op with no result value to use memref type");
+ }
return success();
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a3ef242c29f9..6579add14c50 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -659,3 +659,41 @@ func @pad_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
} : tensor<?x4xi32> to tensor<?x9xi32>
return %0 : tensor<?x9xi32>
}
+
+// -----
+
+// A result-less fill is only accepted on memrefs; on a tensor it is rejected.
+// The function is terminated with `return` so the fill verifier error is the
+// only diagnostic this case can produce.
+func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
+{
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  // expected-error @+1 {{expected fill op with no result value to use memref type}}
+  linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32
+  return
+}
+
+// -----
+
+// A fill whose output operand is a memref must not produce a result value.
+func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+{
+ // expected-error @+1 {{unexpected #results > #outputs}}
+ %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> memref<?x?xf32>
+ return %0 : memref<?x?xf32>
+}
+
+// -----
+
+// A memref fill may not return a value even when that value is tensor-typed;
+// this is diagnosed as "#results > #outputs".
+func @illegal_fill_memref_with_tensor_return
+ (%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
+{
+ // expected-error @+1 {{unexpected #results > #outputs}}
+ %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// The result type of a tensor fill must match the type of its output operand.
+func @illegal_fill_tensor_with_memref_return
+ (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+{
+ // expected-error @+1 {{expected type of operand #0 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
+ %0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
+ return %0 : memref<?x?xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index c4a3247fdc88..44743eaedc8c 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -805,3 +805,12 @@ func @legal_collapsing_reshape_dynamic_memref
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
// CHECK: func @legal_collapsing_reshape_dynamic_memref
// CHECK: linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+
+// -----
+
+// Round-trips the tensor form of linalg.fill, which returns the filled tensor.
+func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @fill_tensor
+//       CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//       CHECK:   %{{.+}} = linalg.fill(%[[INIT]], %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index f52d7fefa8be..f8b996e1ae05 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -41,7 +41,7 @@ func @generic_op_tensors(
%4 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
- affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
+ affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%3 : tensor<?x?x?xf32>) {
@@ -88,7 +88,7 @@ func @indexed_generic_op_tensors(
%4 = linalg.indexed_generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
- affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
+ affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%3 : tensor<?x?x?xf32>) {
@@ -120,3 +120,26 @@ func @indexed_generic_op_tensors(
// CHECK: scf.yield %[[TD1]]
// CHECK: }
// CHECK: return %[[TD0]]
+
+// -----
+
+// Tiling a tensor fill produces a 2-D scf.for nest that threads the init
+// tensor through iter_args, fills a subtensor tile per iteration, and inserts
+// the filled tile back with subtensor_insert.
+func @fill_tensors(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+//       CHECK: func @fill_tensors
+//       CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:     iter_args(%[[ARG4:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+//       CHECK:     %[[YIELD_1:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:       iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) {
+//       CHECK:       %[[FILL_TILE:.+]] = subtensor %[[ARG6]][%[[IV0]], %[[IV1]]]
+//       CHECK:       %[[RESULT_TILE:.+]] = linalg.fill(%[[FILL_TILE]], %{{.+}})
+//       CHECK:       %[[YIELD_2:.+]] = subtensor_insert %[[RESULT_TILE]]
+//  CHECK-SAME:         into %[[ARG6]][%[[IV0]], %[[IV1]]]
+//       CHECK:       scf.yield %[[YIELD_2]]
+//       CHECK:     }
+//       CHECK:     scf.yield %[[YIELD_1]]
+//       CHECK:   }
+//       CHECK:   return %[[RESULT]]
More information about the llvm-branch-commits mailing list