[Mlir-commits] [mlir] f4ae02a - [mlir][linalg] Add a FillOpInterface.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 8 07:48:34 PST 2022
Author: gysit
Date: 2022-03-08T15:48:02Z
New Revision: f4ae02afe7c40890977f4d7222761876ce9475f8
URL: https://github.com/llvm/llvm-project/commit/f4ae02afe7c40890977f4d7222761876ce9475f8
DIFF: https://github.com/llvm/llvm-project/commit/f4ae02afe7c40890977f4d7222761876ce9475f8.diff
LOG: [mlir][linalg] Add a FillOpInterface.
Add a FillOpInterface similar to the contraction and convolution op interfaces. The FillOpInterface is a preparation step to replace linalg.fill by its OpDSL version linalg.fill_tensor. The interface implements the `value()`, `output()`, and `result()` methods that by default are not available on linalg.fill_tensor.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D120725
Added:
mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index df4af026d0c57..b0c705a87a35d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -43,6 +43,9 @@ LogicalResult verifyContractionInterface(Operation *op);
/// Verify that `op` conforms to the ConvolutionOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op);
+/// Verify that `op` conforms to the FillOpInterface: it must be a LinalgOp
+/// with exactly one input and one output, and that input must be a scalar.
+/// On violation an error is emitted on `op` and failure is returned.
+LogicalResult verifyFillInterface(Operation *op);
+
/// Verify that `op` conforms to the invariants of StructuredOpInterface
LogicalResult verifyStructuredOpInterface(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index dbf65aec97880..4ac6bb78653c6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -132,6 +132,50 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
];
}
+// Interface for ops that fill an output operand with a scalar value.
+// Verification is delegated to detail::verifyFillInterface. The default method
+// implementations below assume the fill value is operand 0 and the output is
+// operand 1; ops with a different operand layout must override them.
+def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
+ let description = [{
+ A fill operation is defined in general terms:
+ 1. Has a scalar `value` operand.
+ 2. Has one `output` operand.
+ }];
+ let cppNamespace = "::mlir::linalg";
+ let verify = [{ return detail::verifyFillInterface($_op); }];
+ let methods = [
+ // Default: the fill value is operand 0.
+ InterfaceMethod<
+ /*desc=*/"Return the fill value.",
+ /*retTy=*/"Value",
+ /*methodName=*/"value",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getOperation()->getOperand(0);
+ }]
+ >,
+ // Default: the output is operand 1.
+ InterfaceMethod<
+ /*desc=*/"Return the output operand.",
+ /*retTy=*/"Value",
+ /*methodName=*/"output",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getOperation()->getOperand(1);
+ }]
+ >,
+ // Default: the first result, or a null Value when the op has no results
+ // (e.g. a fill on a memref output).
+ InterfaceMethod<
+ /*desc=*/"Return the result.",
+ /*retTy=*/"Value",
+ /*methodName=*/"result",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ if ($_op.getOperation()->getResults().empty())
+ return nullptr;
+ return $_op.getOperation()->getResults().front();
+ }]
+ >,
+ ];
+}
+
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index e296004603673..7511e268ae850 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2875,6 +2875,8 @@ metadata: !LinalgOpMetadata
Works for arbitrary ranked output tensors since the operation performs scalar
accesses only and is thus rank polymorphic. Numeric casting is performed on
the value operand, promoting it to the same data type as the output.
+ implements:
+ - LinalgFillOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 84e26b150fa32..4c796723c25a7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -408,6 +408,44 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
}
return success();
}
+
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+
+/// Outcome of matching an op against the FillOpInterface requirements. Each
+/// non-Success value names the first requirement that the op failed.
+enum class MatchFillResult {
+ Success = 0,
+ NotLinalgOp,      // The op does not implement LinalgOp.
+ WrongNumOperands, // The op does not have exactly 1 input and 1 output.
+ NotScalarInput    // The single input operand is not a scalar.
+};
+
+/// Classify `op` against the FillOpInterface requirements, checking in order:
+/// LinalgOp-ness, operand counts, then scalar-ness of the single input.
+static MatchFillResult isFillInterfaceImpl(Operation *op) {
+  // Only structured (Linalg) ops can qualify as fills.
+  auto structuredOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!structuredOp)
+    return MatchFillResult::NotLinalgOp;
+
+  // A fill reads exactly one operand and writes exactly one operand.
+  bool hasOneInputAndOneOutput =
+      structuredOp.getNumInputs() == 1 && structuredOp.getNumOutputs() == 1;
+  if (!hasOneInputAndOneOutput)
+    return MatchFillResult::WrongNumOperands;
+
+  // The value written into the output must be a scalar.
+  return structuredOp.isScalar(structuredOp.getInputOperand(0))
+             ? MatchFillResult::Success
+             : MatchFillResult::NotScalarInput;
+}
+
+/// Interface verifier: translate the match outcome into a diagnostic on `op`.
+/// The switch covers every MatchFillResult value, so adding a new enumerator
+/// without handling it here triggers a compiler warning.
+LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
+  switch (isFillInterfaceImpl(op)) {
+  case MatchFillResult::NotLinalgOp:
+    return op->emitError("expected a LinalgOp");
+  case MatchFillResult::WrongNumOperands:
+    return op->emitError("expected op with 1 input and 1 output");
+  case MatchFillResult::NotScalarInput:
+    return op->emitError("expected op with scalar input");
+  case MatchFillResult::Success:
+    break;
+  }
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 7de0a76e87b7c..1de5449e27e31 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -686,6 +686,7 @@ def __init__(self, cpp_name: str):
ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
+# OpDSL handle for the C++ LinalgFillOpInterface; op definitions opt in via
+# `implements(FillOpInterface)`.
+FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
class OpMetadataDef(YAMLObject):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 0ef40613a7ba9..7798d7f9498e3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -671,6 +671,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
accesses only and is thus rank polymorphic. Numeric casting is performed on
the value operand, promoting it to the same data type as the output.
"""
+ implements(FillOpInterface)
O[None] = TypeFn.cast_signed(U, value)
diff --git a/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
new file mode 100644
index 0000000000000..17a5f119cfd50
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+// The op implements FillOpInterface but not LinalgOp, so the interface
+// verifier must reject it.
+func @test_fill_op_not_linalg_op(%arg0 : f32, %arg1 : tensor<?xf32>)
+ -> tensor<?xf32> {
+ // expected-error @+1 {{expected a LinalgOp}}
+ %0 = "test.fill_op_not_linalg_op"(%arg0, %arg1)
+ : (f32, tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Two scalar inputs violate the single-input/single-output requirement.
+#map0 = affine_map<(d0) -> ()>
+#map1 = affine_map<(d0) -> (d0)>
+func @test_fill_op_wrong_num_operands(%arg0 : f32, %arg1 : tensor<?xf32>)
+ -> tensor<?xf32> {
+ // expected-error @+1 {{expected op with 1 input and 1 output}}
+ %0 = test.linalg_fill_op {
+ indexing_maps = [#map0, #map0, #map1],
+ iterator_types = ["parallel"]}
+ ins(%arg0, %arg0 : f32, f32) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+ linalg.yield %arg2 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// The single input is a tensor rather than a scalar, so the scalar-input
+// check must fail.
+#map1 = affine_map<(d0) -> (d0)>
+func @test_fill_op_non_scalar_input(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{expected op with scalar input}}
+ %0 = test.linalg_fill_op {
+ indexing_maps = [#map1, #map1],
+ iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9ab7d9e62e67c..68139eb555338 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2640,6 +2640,64 @@ def TestLinalgConvOp :
}];
}
+//===----------------------------------------------------------------------===//
+// Test LinalgFillOpInterface.
+//===----------------------------------------------------------------------===//
+
+// Implements LinalgFillOpInterface without LinalgStructuredInterface; used to
+// check that the FillOpInterface verifier rejects non-Linalg ops.
+def TestLinalgFillOpNotLinalgOp : TEST_Op<"fill_op_not_linalg_op", [
+ LinalgFillOpInterface]> {
+ let arguments = (ins
+ AnyType:$value, AnyType:$output);
+ let results = (outs AnyRankedTensor:$result);
+}
+
+// Minimal structured op implementing both LinalgStructuredInterface and
+// LinalgFillOpInterface. Inputs/outputs are variadic so that tests can build
+// ops that violate the FillOpInterface invariants (operand count, scalar input).
+def TestLinalgFillOp :
+ TEST_Op<"linalg_fill_op", [AttrSizedOperandSegments, SingleBlock,
+ LinalgStructuredInterface, LinalgFillOpInterface]> {
+
+ let arguments = (ins Variadic<AnyType>:$inputs,
+ Variadic<AnyType>:$outputs);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region AnyRegion:$region);
+
+ let assemblyFormat = [{
+ attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+ `outs` `(` $outputs `:` type($outputs) `)`
+ $region (`->` type($results)^)?
+ }];
+
+ let extraClassDeclaration = [{
+ bool hasIndexSemantics() { return false; }
+
+ // Body builder: yield the last block argument.
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+ mlir::ArrayRef<mlir::NamedAttribute> attrs) {
+ b.create<mlir::linalg::YieldOp>(block.getArguments().back());
+ }
+
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ // Address of the static regionBuilder above.
+ return &regionBuilder;
+ }
+
+ mlir::ArrayAttr iterator_types() {
+ return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+ }
+
+ mlir::ArrayAttr indexing_maps() {
+ return getOperation()->getAttrOfType<mlir::ArrayAttr>("indexing_maps");
+ }
+
+ std::string getLibraryCallName() {
+ return "";
+ }
+
+ // To conform with interface requirement on operand naming.
+ mlir::ValueRange inputs() { return getInputs(); }
+ mlir::ValueRange outputs() { return getOutputs(); }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Test Ops with Default-Valued String Attributes
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list