[llvm-branch-commits] [mlir] 6e8ef3b - [mlir][Linalg] Make Fill operation work on tensors.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 22 14:44:22 PST 2021


Author: MaheshRavishankar
Date: 2021-01-22T14:39:27-08:00
New Revision: 6e8ef3b76ab65960edd6ee99f387e75564d8d9db

URL: https://github.com/llvm/llvm-project/commit/6e8ef3b76ab65960edd6ee99f387e75564d8d9db
DIFF: https://github.com/llvm/llvm-project/commit/6e8ef3b76ab65960edd6ee99f387e75564d8d9db.diff

LOG: [mlir][Linalg] Make Fill operation work on tensors.

Depends on D95109

Added:
    (none)

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile-tensors.mlir

Removed:
    (none)


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 26db4c2f6735..436dab1ade2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -148,8 +148,9 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
 }
 
 def FillOp : LinalgStructured_Op<"fill", []> {
-  let arguments = (ins AnyStridedMemRef:$output,
+  let arguments = (ins AnyShaped:$output,
                    AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
+  let results = (outs Optional<AnyRankedTensor>:$result);
   let extraClassDeclaration = libraryCallName # [{
     ValueRange inputs() { return {}; }
     ValueRange outputs() { return getOperands().take_front(); }
@@ -174,6 +175,14 @@ def FillOp : LinalgStructured_Op<"fill", []> {
     }
   }];
 
+  let assemblyFormat = [{
+    `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
+  }];
+
+  let builders = [
+    OpBuilderDAG<(ins "Value":$output, "Value":$value)>
+  ];
+
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b500eefa9d0c..a6f3576c4240 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -220,6 +220,16 @@ static LogicalResult foldMemRefCast(Operation *op) {
 // LinalgOps.td), we define an overloaded `print` function and a
 // parse`className` function.
 
+//===----------------------------------------------------------------------===//
+// FillOp
+//===----------------------------------------------------------------------===//
+
+void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
+                   Value value) {
+  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
+        value);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericOps
 //===----------------------------------------------------------------------===//
@@ -1726,6 +1736,10 @@ static LogicalResult verify(FillOp op) {
   auto fillType = op.value().getType();
   if (viewType.getElementType() != fillType)
     return op.emitOpError("expects fill type to match view elemental type");
+  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+    return op.emitOpError(
+        "expected fill op with no result value to use memref type");
+  }
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a3ef242c29f9..6579add14c50 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -659,3 +659,41 @@ func @pad_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
   } : tensor<?x4xi32> to tensor<?x9xi32>
   return %0 : tensor<?x9xi32>
 }
+
+// -----
+
+func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
+{
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  // expected-error @+1 {{expected fill op with no result value to use memref type}}
+  linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32
+}
+
+// -----
+
+func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+{
+  // expected-error @+1 {{unexpected #results > #outputs}}
+  %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @illegal_fill_memref_with_tensor_return
+  (%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
+{
+  // expected-error @+1 {{unexpected #results > #outputs}}
+  %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @illegal_fill_tensor_with_memref_return
+  (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+{
+  // expected-error @+1 {{expected type of operand #0 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
+  %0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index c4a3247fdc88..44743eaedc8c 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -805,3 +805,12 @@ func @legal_collapsing_reshape_dynamic_memref
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
 //     CHECK: func @legal_collapsing_reshape_dynamic_memref
 //     CHECK:   linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+
+// -----
+
+func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>

diff  --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index f52d7fefa8be..f8b996e1ae05 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -41,7 +41,7 @@ func @generic_op_tensors(
   %4 = linalg.generic
     {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
-		      affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
+                      affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
     outs(%3 : tensor<?x?x?xf32>) {
@@ -88,7 +88,7 @@ func @indexed_generic_op_tensors(
   %4 = linalg.indexed_generic
     {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
-		      affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
+                      affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
     outs(%3 : tensor<?x?x?xf32>) {
@@ -120,3 +120,26 @@ func @indexed_generic_op_tensors(
 //       CHECK:   scf.yield %[[TD1]]
 //       CHECK: }
 //       CHECK: return %[[TD0]]
+
+// -----
+
+func @fill_tensors(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+//       CHECK: func @fill_tensors
+//       CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-z0-9_]+]]
+//  CHECK-SAME:     iter_args(%[[ARG4:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+//       CHECK:     %[[YIELD_1:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+//  CHECK-SAME:       iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) {
+//       CHECK:       %[[FILL_TILE:.+]] = subtensor %[[ARG6]][%[[IV0]], %[[IV1]]]
+//       CHECK:       %[[RESULT_TILE:.+]] = linalg.fill(%[[FILL_TILE]], %{{.+}})
+//       CHECK:       %[[YIELD_2:.+]] = subtensor_insert %[[RESULT_TILE]]
+//  CHECK-SAME:         into %[[ARG6]][%[[IV0]], %[[IV1]]]
+//       CHECK:       scf.yield %[[YIELD_2]]
+//       CHECK:     }
+//       CHECK:     scf.yield %[[YIELD_1]]
+//       CHECK:   }
+//       CHECK:   return %[[RESULT]]


        


More information about the llvm-branch-commits mailing list