[Mlir-commits] [mlir] f4ae02a - [mlir][linalg] Add a FillOpInterface.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 8 07:48:34 PST 2022


Author: gysit
Date: 2022-03-08T15:48:02Z
New Revision: f4ae02afe7c40890977f4d7222761876ce9475f8

URL: https://github.com/llvm/llvm-project/commit/f4ae02afe7c40890977f4d7222761876ce9475f8
DIFF: https://github.com/llvm/llvm-project/commit/f4ae02afe7c40890977f4d7222761876ce9475f8.diff

LOG: [mlir][linalg] Add a FillOpInterface.

Add a FillOpInterface similar to the contraction and convolution op interfaces. The FillOpInterface is a preparation step to replace linalg.fill by its OpDSL version linalg.fill_tensor. The interface implements the `value()`, `output()`, and `result()` methods that by default are not available on linalg.fill_tensor.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120725

Added: 
    mlir/test/Dialect/Linalg/fill-interface-invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index df4af026d0c57..b0c705a87a35d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -43,6 +43,9 @@ LogicalResult verifyContractionInterface(Operation *op);
 /// Verify that `op` conforms to the ConvolutionOpInterface.
 LogicalResult verifyConvolutionInterface(Operation *op);
 
+/// Verify that `op` conforms to the FillOpInterface.
+LogicalResult verifyFillInterface(Operation *op);
+
 /// Verify that `op` conforms to the invariants of StructuredOpInterface
 LogicalResult verifyStructuredOpInterface(Operation *op);
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index dbf65aec97880..4ac6bb78653c6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -132,6 +132,50 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
   ];
 }
 
+def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
+  let description = [{
+    A fill operation is defined in general terms:
+    1. Has a scalar `value` operand.
+    2. Has one `output` operand.
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyFillInterface($_op); }];
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/"Return the fill value.",
+      /*retTy=*/"Value",
+      /*methodName=*/"value",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getOperation()->getOperand(0);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the output operand.",
+      /*retTy=*/"Value",
+      /*methodName=*/"output",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getOperation()->getOperand(1);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the result.",
+      /*retTy=*/"Value",
+      /*methodName=*/"result",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if ($_op.getOperation()->getResults().empty())
+          return nullptr;
+        return $_op.getOperation()->getResults().front();
+      }]
+    >,
+  ];
+}
+
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
 def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let cppNamespace = "::mlir::linalg";

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index e296004603673..7511e268ae850 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2875,6 +2875,8 @@ metadata: !LinalgOpMetadata
     Works for arbitrary ranked output tensors since the operation performs scalar
     accesses only and is thus rank polymorphic. Numeric casting is performed on
     the value operand, promoting it to the same data type as the output.
+  implements:
+  - LinalgFillOpInterface
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 84e26b150fa32..4c796723c25a7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -408,6 +408,44 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+
+enum class MatchFillResult {
+  Success = 0,
+  NotLinalgOp,
+  WrongNumOperands,
+  NotScalarInput
+};
+
+static MatchFillResult isFillInterfaceImpl(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return MatchFillResult::NotLinalgOp;
+  if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1)
+    return MatchFillResult::WrongNumOperands;
+
+  OpOperand *value = linalgOp.getInputOperand(0);
+  if (!linalgOp.isScalar(value))
+    return MatchFillResult::NotScalarInput;
+
+  return MatchFillResult::Success;
+}
+
+LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
+  auto res = isFillInterfaceImpl(op);
+  if (res == MatchFillResult::NotLinalgOp)
+    return op->emitError("expected a LinalgOp");
+  if (res == MatchFillResult::WrongNumOperands)
+    return op->emitError("expected op with 1 input and 1 output");
+  if (res == MatchFillResult::NotScalarInput)
+    return op->emitError("expected op with scalar input");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // StructuredOpInterface implementation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 7de0a76e87b7c..1de5449e27e31 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -686,6 +686,7 @@ def __init__(self, cpp_name: str):
 
 ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
 ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
+FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
 
 
 class OpMetadataDef(YAMLObject):

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 0ef40613a7ba9..7798d7f9498e3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -671,6 +671,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
   accesses only and is thus rank polymorphic. Numeric casting is performed on
   the value operand, promoting it to the same data type as the output.
   """
+  implements(FillOpInterface)
   O[None] = TypeFn.cast_signed(U, value)
 
 

diff  --git a/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
new file mode 100644
index 0000000000000..17a5f119cfd50
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+func @test_fill_op_not_linalg_op(%arg0 : f32, %arg1 : tensor<?xf32>)
+     -> tensor<?xf32> {
+  // expected-error @+1 {{expected a LinalgOp}}
+  %0 = "test.fill_op_not_linalg_op"(%arg0, %arg1)
+      : (f32, tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> ()>
+#map1 = affine_map<(d0) -> (d0)>
+func @test_fill_op_wrong_num_operands(%arg0 : f32, %arg1 : tensor<?xf32>)
+     -> tensor<?xf32> {
+  // expected-error @+1 {{expected op with 1 input and 1 output}}
+  %0 = test.linalg_fill_op {
+      indexing_maps = [#map0, #map0, #map1],
+      iterator_types = ["parallel"]}
+      ins(%arg0, %arg0 : f32, f32) outs(%arg1 : tensor<?xf32>) {
+      ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+         linalg.yield  %arg2 : f32
+      } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+#map1 = affine_map<(d0) -> (d0)>
+func @test_fill_op_non_scalar_input(%arg0 : tensor<?xf32>,
+    %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error @+1 {{expected op with scalar input}}
+  %0 = test.linalg_fill_op {
+      indexing_maps = [#map1, #map1],
+      iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
+      ^bb0(%arg2 : f32, %arg3 : f32):
+         linalg.yield  %arg2 : f32
+      } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9ab7d9e62e67c..68139eb555338 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2640,6 +2640,64 @@ def TestLinalgConvOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test LinalgFillOpInterface.
+//===----------------------------------------------------------------------===//
+
+def TestLinalgFillOpNotLinalgOp : TEST_Op<"fill_op_not_linalg_op", [
+    LinalgFillOpInterface]> {
+  let arguments = (ins
+    AnyType:$value, AnyType:$output);
+  let results = (outs AnyRankedTensor:$result);
+}
+
+def TestLinalgFillOp :
+  TEST_Op<"linalg_fill_op", [AttrSizedOperandSegments, SingleBlock,
+      LinalgStructuredInterface, LinalgFillOpInterface]> {
+
+  let arguments = (ins Variadic<AnyType>:$inputs,
+    Variadic<AnyType>:$outputs);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+    `outs` `(` $outputs `:` type($outputs) `)`
+    $region (`->` type($results)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    bool hasIndexSemantics() { return false; }
+
+    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+                              mlir::ArrayRef<mlir::NamedAttribute> attrs) {
+      b.create<mlir::linalg::YieldOp>(block.getArguments().back());
+    }
+
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
+    getRegionBuilder() {
+      return &regionBuilder;
+    }
+
+    mlir::ArrayAttr iterator_types() {
+      return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+    }
+
+    mlir::ArrayAttr indexing_maps() {
+      return getOperation()->getAttrOfType<mlir::ArrayAttr>("indexing_maps");
+    }
+
+    std::string getLibraryCallName() {
+      return "";
+    }
+
+    // To conform with interface requirement on operand naming.
+    mlir::ValueRange inputs() { return getInputs(); }
+    mlir::ValueRange outputs() { return getOutputs(); }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Ops with Default-Valued String Attributes
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list