[Mlir-commits] [mlir] 2bd6077 - DestinationPassingStyle: allow additional non-tensor results
Benoit Jacob
llvmlistbot at llvm.org
Fri May 12 08:35:38 PDT 2023
Author: Benoit Jacob
Date: 2023-05-12T15:35:24Z
New Revision: 2bd6077d7fb61a54ebfbcf2fa210e07e8005c4bf
URL: https://github.com/llvm/llvm-project/commit/2bd6077d7fb61a54ebfbcf2fa210e07e8005c4bf
DIFF: https://github.com/llvm/llvm-project/commit/2bd6077d7fb61a54ebfbcf2fa210e07e8005c4bf.diff
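LOG: DestinationPassingStyle: allow additional non-tensor results
Also some simplifications:
* `outputBufferOperands` was unused.
* The check that the number of operands equals the number of inputs
plus the number of inits appeared to be vacuously true (?), so it was removed.
* The requirement of at least one init operand was dropped, so ops with no
inits (e.g. ones producing only non-tensor results) now verify.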
Differential Revision: https://reviews.llvm.org/D150376
Added:
mlir/test/Interfaces/DestinationStyleOpInterface/verify-destination-style-op-interface.mlir
Modified:
mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
index a9bab23f1a72c..f344ea656b247 100644
--- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -22,35 +22,38 @@ OpOperandVector::operator SmallVector<Value>() {
return result;
}
+/// Returns the number of results of `op` whose type is a TensorType (ranked or
+/// unranked). Results of any other type (e.g. index) are not counted, which is
+/// what allows destination-style ops to carry additional non-tensor results.
+static size_t getNumTensorResults(Operation *op) {
+  return llvm::count_if(op->getResultTypes(),
+                        [](Type type) { return isa<TensorType>(type); });
+}
+
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
DestinationStyleOpInterface dstStyleOp =
cast<DestinationStyleOpInterface>(op);
- SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
+ SmallVector<OpOperand *> outputTensorOperands;
for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) {
Type type = operand->get().getType();
- if (isa<MemRefType>(type)) {
- outputBufferOperands.push_back(operand);
- } else if (isa<RankedTensorType>(type)) {
+ if (isa<RankedTensorType>(type)) {
outputTensorOperands.push_back(operand);
- } else {
+ } else if (!isa<MemRefType>(type)) {
return op->emitOpError("expected that operand #")
<< operand->getOperandNumber()
<< " is a ranked tensor or a ranked memref";
}
}
- // Expect at least one output operand.
- int64_t numInputs = dstStyleOp.getNumDpsInputs();
- int64_t numInits = dstStyleOp.getNumDpsInits();
- if (numInits == 0)
- return op->emitOpError("expected at least one output operand");
- if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numInits)))
- return failure();
- // Verify the number of results matches the number of output tensors.
- if (op->getNumResults() != outputTensorOperands.size())
- return op->emitOpError("expected the number of results (")
- << op->getNumResults()
+ // Verify the number of tensor results matches the number of output tensors.
+ if (getNumTensorResults(op) != outputTensorOperands.size())
+ return op->emitOpError("expected the number of tensor results (")
+ << getNumTensorResults(op)
<< ") to be equal to the number of output tensors ("
<< outputTensorOperands.size() << ")";
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 475348f3cb176..dbc93d56e2a9e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -326,7 +326,7 @@ func.func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
// linalg.fill with a tensor init must produce a tensor result; here it has
// none, so the tensor-result count (0) does not match the tensor inits (1).
func.func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
{
%0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
- // expected-error @+1 {{expected the number of results (0) to be equal to the number of output tensors (1)}}
+ // expected-error @+1 {{expected the number of tensor results (0) to be equal to the number of output tensors (1)}}
linalg.fill ins(%arg2 : f32) outs(%0 : tensor<?x?xf32>)
}
@@ -335,7 +335,7 @@ func.func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f
// linalg.fill with only a memref init must not produce a tensor result: the
// tensor-result count (1) does not match the tensor inits (0).
func.func @illegal_fill_memref_with_tensor_return
(%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
{
- // expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}}
+ // expected-error @+1 {{expected the number of tensor results (1) to be equal to the number of output tensors (0)}}
%0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
diff --git a/mlir/test/Interfaces/DestinationStyleOpInterface/verify-destination-style-op-interface.mlir b/mlir/test/Interfaces/DestinationStyleOpInterface/verify-destination-style-op-interface.mlir
new file mode 100644
index 0000000000000..82ec924477359
--- /dev/null
+++ b/mlir/test/Interfaces/DestinationStyleOpInterface/verify-destination-style-op-interface.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Each case is named ins_<inputs>_outs_<inits>_results_<results>. The verifier
+// compares the number of tensor results with the number of tensor inits;
+// non-tensor results (e.g. index) and memref inits count toward neither.
+
+func.func @ins_1_index_outs_none_results_1_index(%arg0 : index) -> index {
+  %0 = test.destination_style_op ins(%arg0 : index) -> index
+  func.return %0 : index
+}
+
+// -----
+
+func.func @ins_1_index_outs_1_tensor_results_1_index(%arg0 : index, %arg1 : tensor<2x2xf32>) -> index {
+  // expected-error @+1 {{op expected the number of tensor results (0) to be equal to the number of output tensors (1)}}
+  %0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : tensor<2x2xf32>) -> index
+  func.return %0 : index
+}
+
+// -----
+
+func.func @ins_1_tensor_outs_none_results_1_index(%arg0 :tensor<2x2xf32>) -> index {
+  %0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) -> index
+  func.return %0 : index
+}
+
+// -----
+
+func.func @ins_1_tensor_outs_1_tensor_results_1_index(%arg0 :tensor<2x2xf32>, %arg1 : tensor<2x2xf32>) -> index {
+  // expected-error @+1 {{op expected the number of tensor results (0) to be equal to the number of output tensors (1)}}
+  %0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) outs(%arg1 : tensor<2x2xf32>) -> index
+  func.return %0 : index
+}
+
+// -----
+
+func.func @ins_1_index_outs_none_results_1_tensor(%arg0 : index) -> tensor<2x2xf32> {
+  // expected-error @+1 {{op expected the number of tensor results (1) to be equal to the number of output tensors (0)}}
+  %0 = test.destination_style_op ins(%arg0 : index) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @ins_1_index_outs_1_tensor_results_1_tensor(%arg0 : index, %arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : tensor<2x2xf32>) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @ins_1_tensor_outs_none_results_1_tensor(%arg0 :tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error @+1 {{op expected the number of tensor results (1) to be equal to the number of output tensors (0)}}
+  %0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @ins_1_tensor_outs_1_tensor_results_1_tensor(%arg0 :tensor<2x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) outs(%arg1 : tensor<2x2xf32>) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @ins_1_index_outs_1_memref_results_none(%arg0 : index, %arg1 : memref<2x2xf32>) {
+  test.destination_style_op ins(%arg0 : index) outs(%arg1 : memref<2x2xf32>)
+  func.return
+}
+
+// -----
+
+func.func @ins_1_index_outs_1_memref_results_1_index(%arg0 : index, %arg1 : memref<2x2xf32>) -> index {
+  %0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : memref<2x2xf32>) -> index
+  func.return %0 : index
+}
+
+// -----
+
+func.func @ins_1_index_outs_1_memref_results_1_tensor(%arg0 : index, %arg1 : memref<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error @+1 {{op expected the number of tensor results (1) to be equal to the number of output tensors (0)}}
+  %0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : memref<2x2xf32>) -> tensor<2x2xf32>
+  func.return %0 : tensor<2x2xf32>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a108f278c1ffa..dd9e62b03b89d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2908,6 +2908,34 @@ def OpCrashShort : TEST_Op<"op_crash_short"> {
def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>;
+//===----------------------------------------------------------------------===//
+// Test DestinationStyleOpInterface.
+//===----------------------------------------------------------------------===//
+
+// Test op for DestinationStyleOpInterface verification. The DPS inits are the
+// `outputs` operand segment.
+def TestDestinationStyleOp :
+    TEST_Op<"destination_style_op", [
+        DestinationStyleOpInterface,
+        AttrSizedOperandSegments]> {
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyType>:$outputs,
+    Variadic<AnyType>:$other_operands);
+  let results = (outs Variadic<AnyType>:$results);
+  let assemblyFormat = [{
+    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+    (`outs` `(` $outputs^ `:` type($outputs) `)`)?
+    (`(` $other_operands^ `:` type($other_operands) `)`)?
+    (`->` type($results)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    // Operands are laid out in declaration order: inputs, outputs,
+    // other_operands. The init range must therefore start right after the
+    // inputs, not at the end of the operand list (which would select
+    // other_operands whenever that segment is non-empty).
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      int64_t start = getInputs().size();
+      return {start, start + static_cast<int64_t>(getOutputs().size())};
+    }
+  }];
+}
+
//===----------------------------------------------------------------------===//
// Test LinalgConvolutionOpInterface.
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list