[Mlir-commits] [mlir] 0e8c68c - [mlir][Interfaces] Fix DestinationStyleOpInterface for vector ops
Matthias Springer
llvmlistbot at llvm.org
Mon Jul 17 08:40:33 PDT 2023
Author: Matthias Springer
Date: 2023-07-17T17:40:18+02:00
New Revision: 0e8c68c30102e91b992996c4f45e5bc03280923a
URL: https://github.com/llvm/llvm-project/commit/0e8c68c30102e91b992996c4f45e5bc03280923a
DIFF: https://github.com/llvm/llvm-project/commit/0e8c68c30102e91b992996c4f45e5bc03280923a.diff
LOG: [mlir][Interfaces] Fix DestinationStyleOpInterface for vector ops
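This revision fixes `hasTensorSemantics` and `hasBufferSemantics` for vector transfer ops, which may have vector operands. `VectorType` implements `ShapedType`, so the previous `isScalar` check (`!isa<ShapedType>`) classified vector operands as non-scalar, and both queries returned false for any op with a vector operand. Vector operands should not affect whether an op has tensor or buffer semantics, so `isScalar` now treats every value that is neither a memref nor a tensor as a scalar, and the two queries only look at the tensor and memref operands. Also implement `DestinationStyleOpInterface` on `TransferReadOp` so that `hasTensorSemantics`/`hasBufferSemantics` can be called on it. (The op has no inits, but this makes it symmetric to `TransferWriteOp`.)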
Differential Revision: https://reviews.llvm.org/D155469
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 555ed0bec3c9a6..27080e84d46c81 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1178,7 +1178,8 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- AttrSizedOperandSegments
+ AttrSizedOperandSegments,
+ DestinationStyleOpInterface
]>,
Arguments<(ins AnyShaped:$source,
Variadic<Index>:$indices,
@@ -1400,6 +1401,10 @@ def Vector_TransferReadOp :
let extraClassDeclaration = [{
// MaskableOpInterface methods.
bool supportsPassthru() { return true; }
+
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ return {0, 0}; // empty range (no init operands)
+ }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
index ab12f074203d3e..ff2da985c53d7c 100644
--- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
@@ -24,15 +24,15 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
position [start, end). The positions are defined by getDpsInitsPositionRange
method.
- If the op has "tensor semantics", then the input operands are either scalars
- or ranked tensors. The init operands are ranked tensors and every tensor
- init is tied to a corresponding tensor OpResult in a 1-to-1 fashion.
- The i-th init tensor is tied to the i-th OpResult. The op may not have any
- additional OpResults. Init operands and their tied OpResults have the same
- type.
+ If the op has "tensor semantics", then the input operands are either ranked
+ tensors or other non-tensor/memref types ("scalars"). The init operands are
+ ranked tensors and every tensor init is tied to a corresponding tensor
+ OpResult in a 1-to-1 fashion. The i-th init tensor is tied to the i-th
+ OpResult. The op may not have any additional OpResults. Init operands and
+ their tied OpResults have the same type.
If the op has "buffer semantics", then the input operands are either ranked
- memrefs or other non-tensor types, e.g. scalar types. Furthermore, the
+ memrefs or other non-tensor/memref types ("scalar" types). Furthermore, the
init operands are ranked memrefs and the op has no results.
Destination-passing style abstraction makes certain transformations easier.
@@ -194,14 +194,17 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
}]
>,
InterfaceMethod<
- /*desc=*/"Return true if the `opOperand` is a scalar value.",
+ /*desc=*/[{
+ Return true if the `opOperand` is a scalar value. A scalar is defined
+ as neither a memref nor a tensor value.
+ }],
/*retTy=*/"bool",
/*methodName=*/"isScalar",
/*args=*/(ins "::mlir::OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == $_op.getOperation());
- return !::llvm::isa<ShapedType>(opOperand->get().getType());
+ return !::llvm::isa<MemRefType, TensorType>(opOperand->get().getType());
}]
>,
InterfaceMethod<
@@ -235,32 +238,49 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
// Other interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
- /*desc=*/"Return whether the op has only ranked MemRef input/inits.",
+ /*desc=*/[{
+ Return whether the op has buffer semantics. That is the case if the op
+ has no tensor operands and at least one memref operand.
+ }],
/*retTy=*/"bool",
/*methodName=*/"hasBufferSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op->getNumResults() == 0 &&
- ::llvm::all_of($_op->getOpOperands(),
- [&](::mlir::OpOperand &opOperand) {
- return isScalar(&opOperand) ||
- ::llvm::isa<::mlir::MemRefType>(opOperand.get().getType());
- });
+ // No tensors.
+ auto isTensor = [](Value v){
+ return ::llvm::isa<::mlir::RankedTensorType>(v.getType());
+ };
+ if (::llvm::any_of($_op->getOperands(), isTensor))
+ return false;
+ // At least one memref.
+ auto isMemref = [](Value v){
+ return ::llvm::isa<::mlir::MemRefType>(v.getType());
+ };
+ return llvm::any_of($_op->getOperands(), isMemref);
}]
>,
InterfaceMethod<
- /*desc=*/"Return whether the op has only ranked tensor inputs/inits.",
+ /*desc=*/[{
+ Return whether the op has tensor semantics. That is the case if the op
+ has no memref operands and at least one tensor operand.
+ }],
/*retTy=*/"bool",
/*methodName=*/"hasTensorSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return ::llvm::all_of($_op->getOpOperands(),
- [&](::mlir::OpOperand &opOperand) {
- return isScalar(&opOperand) ||
- ::llvm::isa<::mlir::RankedTensorType>(opOperand.get().getType());
- });
+ // No memrefs.
+ auto isMemref = [](Value v){
+ return ::llvm::isa<::mlir::MemRefType>(v.getType());
+ };
+ if (::llvm::any_of($_op->getOperands(), isMemref))
+ return false;
+ // At least one tensor.
+ auto isTensor = [](Value v){
+ return ::llvm::isa<::mlir::RankedTensorType>(v.getType());
+ };
+ return llvm::any_of($_op->getOperands(), isTensor);
}]
>
];
More information about the Mlir-commits
mailing list