[Mlir-commits] [mlir] 47cbd9f - [mlir][Vector] NFC - Improve VectorInterfaces
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jul 20 05:25:14 PDT 2020
Author: Nicolas Vasilache
Date: 2020-07-20T08:24:22-04:00
New Revision: 47cbd9f92282e3a19f161053cfbf77a7691de43e
URL: https://github.com/llvm/llvm-project/commit/47cbd9f92282e3a19f161053cfbf77a7691de43e
DIFF: https://github.com/llvm/llvm-project/commit/47cbd9f92282e3a19f161053cfbf77a7691de43e.diff
LOG: [mlir][Vector] NFC - Improve VectorInterfaces
This revision improves and makes better use of OpInterfaces for the Vector dialect.
Differential Revision: https://reviews.llvm.org/D84053
Added:
mlir/include/mlir/Interfaces/VectorInterfaces.h
mlir/include/mlir/Interfaces/VectorInterfaces.td
mlir/lib/Interfaces/VectorInterfaces.cpp
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/Dialect/Vector/VectorUtils.h
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/StandardOps/CMakeLists.txt
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Dialect/Vector/VectorUtils.cpp
mlir/lib/Interfaces/CMakeLists.txt
Removed:
mlir/include/mlir/Interfaces/VectorUnrollInterface.h
mlir/include/mlir/Interfaces/VectorUnrollInterface.td
mlir/lib/Interfaces/VectorUnrollInterface.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 0f24d74dcac2..2500343c0af3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -21,7 +21,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/VectorUnrollInterface.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
// Pull in all enum type definitions and utility function declarations.
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 702b912d3103..78307b897476 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -18,7 +18,7 @@ include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/VectorUnrollInterface.td"
+include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
def StandardOps_Dialect : Dialect {
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 0f6aa66e926c..edf9557df389 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -19,7 +19,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/VectorUnrollInterface.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
namespace mlir {
class MLIRContext;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 8880c288b648..10a4498b0bbd 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -15,7 +15,7 @@
include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/VectorUnrollInterface.td"
+include "mlir/Interfaces/VectorInterfaces.td"
def Vector_Dialect : Dialect {
let name = "vector";
@@ -905,34 +905,9 @@ def Vector_ExtractStridedSliceOp :
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
}
-def Vector_TransferOpUtils {
- code extraTransferDeclaration = [{
- static StringRef getMaskedAttrName() { return "masked"; }
- static StringRef getPermutationMapAttrName() { return "permutation_map"; }
- bool isMaskedDim(unsigned dim) {
- return !masked() ||
- masked()->cast<ArrayAttr>()[dim].cast<BoolAttr>().getValue();
- }
- MemRefType getMemRefType() {
- return memref().getType().cast<MemRefType>();
- }
- VectorType getVectorType() {
- return vector().getType().cast<VectorType>();
- }
- // Number of dimensions that participate in the permutation map.
- unsigned getTransferRank() {
- return permutation_map().getNumResults();
- }
- // Number of leading dimensions that do not participate in the permutation
- // map.
- unsigned getLeadingMemRefRank() {
- return getMemRefType().getRank() - permutation_map().getNumResults();
- }
- }];
-}
-
def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
+ DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
@@ -1090,23 +1065,12 @@ def Vector_TransferReadOp :
"ArrayRef<bool> maybeMasked = {}">
];
- let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #
- [{
- /// Build the default minor identity map suitable for a vector transfer.
- /// This also handles the case memref<... x vector<...>> -> vector<...> in
- /// which the rank of the identity map must take the vector element type
- /// into account.
- static AffineMap getTransferMinorIdentityMap(
- MemRefType memRefType, VectorType vectorType) {
- return impl::getTransferMinorIdentityMap(memRefType, vectorType);
- }
- }];
-
let hasFolder = 1;
}
def Vector_TransferWriteOp :
Vector_Op<"transfer_write", [
+ DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
@@ -1183,18 +1147,6 @@ def Vector_TransferWriteOp :
"Value memref, ValueRange indices, AffineMap permutationMap">,
];
- let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #
- [{
- /// Build the default minor identity map suitable for a vector transfer.
- /// This also handles the case memref<... x vector<...>> -> vector<...> in
- /// which the rank of the identity map must take the vector element type
- /// into account.
- static AffineMap getTransferMinorIdentityMap(
- MemRefType memRefType, VectorType vectorType) {
- return impl::getTransferMinorIdentityMap(memRefType, vectorType);
- }
- }];
-
let hasFolder = 1;
}
diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 19f7f9538307..448004db32fa 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -153,6 +153,12 @@ AffineMap
makePermutationMap(Operation *op, ArrayRef<Value> indices,
const DenseMap<Operation *, unsigned> &loopToVectorDim);
+/// Build the default minor identity map suitable for a vector transfer. This
+/// also handles the case memref<... x vector<...>> -> vector<...> in which the
+/// rank of the identity map must take the vector element type into account.
+AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+ VectorType vectorType);
+
namespace matcher {
/// Matches vector.transfer_read, vector.transfer_write and ops that return a
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 0de2b5a8688b..65e19f3eec1b 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,6 @@ add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(SideEffectInterfaces)
-add_mlir_interface(VectorUnrollInterface)
+add_mlir_interface(VectorInterfaces)
add_mlir_interface(ViewLikeInterface)
diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h b/mlir/include/mlir/Interfaces/VectorInterfaces.h
similarity index 59%
rename from mlir/include/mlir/Interfaces/VectorUnrollInterface.h
rename to mlir/include/mlir/Interfaces/VectorInterfaces.h
index a68cc3411533..2134969e4020 100644
--- a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.h
@@ -1,4 +1,4 @@
-//===- VectorUnrollInterface.h - Vector unrolling interface ---------------===//
+//===- VectorInterfaces.h - Vector interfaces -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,18 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements the operation interface for vector ops that can be
-// unrolled.
+// This file implements the operation interfaces for vector ops.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
-#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
+#ifndef MLIR_INTERFACES_VECTORINTERFACES_H
+#define MLIR_INTERFACES_VECTORINTERFACES_H
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
/// Include the generated interface declarations.
-#include "mlir/Interfaces/VectorUnrollInterface.h.inc"
+#include "mlir/Interfaces/VectorInterfaces.h.inc"
-#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
+#endif // MLIR_INTERFACES_VECTORINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
new file mode 100644
index 000000000000..aefbb7d47117
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -0,0 +1,194 @@
+//===- VectorInterfaces.td - Vector interfaces -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for operations on vectors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_VECTORINTERFACES
+#define MLIR_INTERFACES_VECTORINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
+ let description = [{
+ Encodes properties of an operation on vectors that can be unrolled.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the shape ratio of unrolling to the target vector shape
+ `targetShape`. Return `None` if the op cannot be unrolled to the target
+ vector shape.
+ }],
+ /*retTy=*/"Optional<SmallVector<int64_t, 4>>",
+ /*methodName=*/"getShapeForUnroll",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert($_op.getOperation()->getNumResults() == 1);
+ auto vt = $_op.getResult().getType().
+ template dyn_cast<VectorType>();
+ if (!vt)
+ return None;
+ SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end());
+ return res;
+ }]
+ >,
+ ];
+}
+
+def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
+ let description = [{
+ Encodes properties of a transfer operation that reads or writes a vector from/to a memref.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ StaticInterfaceMethod<
+ /*desc=*/"Return the `masked` attribute name.",
+ /*retTy=*/"StringRef",
+ /*methodName=*/"getMaskedAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/ [{ return "masked"; }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/"Return the `permutation_map` attribute name.",
+ /*retTy=*/"StringRef",
+ /*methodName=*/"getPermutationMapAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/ [{ return "permutation_map"; }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return `false` when the `masked` attribute at dimension
+ `dim` is set to `false`. Return `true` otherwise.}],
+ /*retTy=*/"bool",
+ /*methodName=*/"isMaskedDim",
+ /*args=*/(ins "unsigned":$dim),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return !$_op.masked() ||
+ $_op.masked()->template cast<ArrayAttr>()[dim]
+ .template cast<BoolAttr>().getValue();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the memref operand.",
+ /*retTy=*/"Value",
+ /*methodName=*/"memref",
+ /*args=*/(ins),
+ /*methodBody=*/"return $_op.memref();"
+ /*defaultImplementation=*/
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the vector operand or result.",
+ /*retTy=*/"Value",
+ /*methodName=*/"vector",
+ /*args=*/(ins),
+ /*methodBody=*/"return $_op.vector();"
+ /*defaultImplementation=*/
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the indices operands.",
+ /*retTy=*/"ValueRange",
+ /*methodName=*/"indices",
+ /*args=*/(ins),
+ /*methodBody=*/"return $_op.indices();"
+ /*defaultImplementation=*/
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the permutation map.",
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"permutation_map",
+ /*args=*/(ins),
+ /*methodBody=*/"return $_op.permutation_map();"
+ /*defaultImplementation=*/
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the `masked` boolean ArrayAttr.",
+ /*retTy=*/"Optional<ArrayAttr>",
+ /*methodName=*/"masked",
+ /*args=*/(ins),
+ /*methodBody=*/"return $_op.masked();"
+ /*defaultImplementation=*/
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the MemRefType.",
+ /*retTy=*/"MemRefType",
+ /*methodName=*/"getMemRefType",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/
+ "return $_op.memref().getType().template cast<MemRefType>();"
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the VectorType.",
+ /*retTy=*/"VectorType",
+ /*methodName=*/"getVectorType",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/
+ "return $_op.vector().getType().template cast<VectorType>();"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Return the number of dimensions that participate in the
+ permutation map.}],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getTransferRank",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/
+ "return $_op.permutation_map().getNumResults();"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Return the number of leading memref dimensions that do not
+ participate in the permutation map.}],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getLeadingMemRefRank",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/
+ "return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Helper function to account for the fact that `permutationMap` results and
+ `op.indices` sizes may not match and may not be aligned. The first
+ `getLeadingMemRefRank()` indices may just be indexed and not transferred
+ from/into the vector.
+ For example:
+ ```
+ vector.transfer %0[%i, %j, %k, %c0] :
+ memref<?x?x?x?xf32>, vector<2x4xf32>
+ ```
+ with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`.
+ Provide a zip function to coiterate on 2 running indices: `resultIdx` and
+ `indicesIdx` which accounts for this misalignment.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"zipResultAndIndexing",
+ /*args=*/(ins "llvm::function_ref<void(int64_t, int64_t)>":$fun),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ for (int64_t resultIdx = 0,
+ indicesIdx = $_op.getLeadingMemRefRank(),
+ eResult = $_op.getTransferRank();
+ resultIdx < eResult;
+ ++resultIdx, ++indicesIdx)
+ fun(resultIdx, indicesIdx);
+ }]
+ >,
+ ];
+}
+
+#endif // MLIR_INTERFACES_VECTORINTERFACES
diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
deleted file mode 100644
index 166780b20e77..000000000000
--- a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
+++ /dev/null
@@ -1,46 +0,0 @@
-//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines the interface for operations on vectors that can be unrolled.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE
-#define MLIR_INTERFACES_VECTORUNROLLINTERFACE
-
-include "mlir/IR/OpBase.td"
-
-def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
- let description = [{
- Encodes properties of an operation on vectors that can be unrolled.
- }];
- let cppNamespace = "::mlir";
-
- let methods = [
- InterfaceMethod<[{
- Returns the shape ratio of unrolling to the target vector shape
- `targetShape`. Returns `None` if the op cannot be unrolled to the target
- vector shape.
- }],
- "Optional<SmallVector<int64_t, 4>>",
- "getShapeForUnroll",
- (ins),
- /*methodBody=*/[{}],
- [{
- auto vt = this->getOperation()->getResult(0).getType().
- template dyn_cast<VectorType>();
- if (!vt)
- return None;
- SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end());
- return res;
- }]
- >,
- ];
-}
-
-#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index d0529668b2ee..ea368c9eb14e 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -249,8 +249,8 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
indexing.append(minorOffsets.begin(), minorOffsets.end());
Value memref = xferOp.memref();
- auto map = TransferReadOp::getTransferMinorIdentityMap(
- xferOp.getMemRefType(), minorVectorType);
+ auto map =
+ getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
ArrayAttr masked;
if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
@@ -353,8 +353,8 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
result = vector_extract(xferOp.vector(), majorIvs);
else
result = std_load(alloc, majorIvs);
- auto map = TransferWriteOp::getTransferMinorIdentityMap(
- xferOp.getMemRefType(), minorVectorType);
+ auto map =
+ getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
ArrayAttr masked;
if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index cb7540b46cf8..180fe069b681 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -82,8 +82,8 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
/// Return true if we can prove that the transfer operations access disjoint
/// memory.
-template <typename TransferTypeA, typename TransferTypeB>
-static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) {
+static bool isDisjoint(VectorTransferOpInterface transferA,
+ VectorTransferOpInterface transferB) {
if (transferA.memref() != transferB.memref())
return false;
// For simplicity only look at transfer of same type.
@@ -91,8 +91,8 @@ static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) {
return false;
unsigned rankOffset = transferA.getLeadingMemRefRank();
for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
- auto indexA = transferA.indices()[i].template getDefiningOp<ConstantOp>();
- auto indexB = transferB.indices()[i].template getDefiningOp<ConstantOp>();
+ auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
+ auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();
// If any of the indices are dynamic we cannot prove anything.
if (!indexA || !indexB)
continue;
@@ -100,15 +100,15 @@ static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) {
if (i < rankOffset) {
// For dimension used as index if we can prove that index are different we
// know we are accessing disjoint slices.
- if (indexA.getValue().template cast<IntegerAttr>().getInt() !=
- indexB.getValue().template cast<IntegerAttr>().getInt())
+ if (indexA.getValue().cast<IntegerAttr>().getInt() !=
+ indexB.getValue().cast<IntegerAttr>().getInt())
return true;
} else {
// For this dimension, we slice a part of the memref we need to make sure
// the intervals accessed don't overlap.
int64_t distance =
- std::abs(indexA.getValue().template cast<IntegerAttr>().getInt() -
- indexB.getValue().template cast<IntegerAttr>().getInt());
+ std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
+ indexB.getValue().cast<IntegerAttr>().getInt());
if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
return true;
}
@@ -185,11 +185,17 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
continue;
if (auto transferWriteUse =
dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
- if (!isDisjoint(transferWrite, transferWriteUse))
+ if (!isDisjoint(
+ cast<VectorTransferOpInterface>(transferWrite.getOperation()),
+ cast<VectorTransferOpInterface>(
+ transferWriteUse.getOperation())))
return WalkResult::advance();
} else if (auto transferReadUse =
dyn_cast<vector::TransferReadOp>(use.getOwner())) {
- if (!isDisjoint(transferWrite, transferReadUse))
+ if (!isDisjoint(
+ cast<VectorTransferOpInterface>(transferWrite.getOperation()),
+ cast<VectorTransferOpInterface>(
+ transferReadUse.getOperation())))
return WalkResult::advance();
} else {
// Unknown use, we cannot prove that it doesn't alias with the
diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index 7d61aea3116e..06284f5d1daa 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -15,7 +15,7 @@ add_mlir_dialect_library(MLIRStandardOps
MLIREDSC
MLIRIR
MLIRSideEffectInterfaces
- MLIRVectorUnrollInterface
+ MLIRVectorInterfaces
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 69a329917228..d6ba987e6622 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -19,5 +19,5 @@ add_mlir_dialect_library(MLIRVector
MLIRSCF
MLIRLoopAnalysis
MLIRSideEffectInterfaces
- MLIRVectorUnrollInterface
+ MLIRVectorInterfaces
)
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5e01fa26f32e..03c4079ef171 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1466,22 +1466,6 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
// TransferReadOp
//===----------------------------------------------------------------------===//
-/// Build the default minor identity map suitable for a vector transfer. This
-/// also handles the case memref<... x vector<...>> -> vector<...> in which the
-/// rank of the identity map must take the vector element type into account.
-AffineMap
-mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
- VectorType vectorType) {
- int64_t elementVectorRank = 0;
- VectorType elementVectorType =
- memRefType.getElementType().dyn_cast<VectorType>();
- if (elementVectorType)
- elementVectorRank += elementVectorType.getRank();
- return AffineMap::getMinorIdentityMap(
- memRefType.getRank(), vectorType.getRank() - elementVectorRank,
- memRefType.getContext());
-}
-
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
@@ -1600,11 +1584,10 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
}
-template <typename TransferOp>
-static void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
+static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 2> elidedAttrs;
- if (op.permutation_map() == TransferOp::getTransferMinorIdentityMap(
- op.getMemRefType(), op.getVectorType()))
+ if (op.permutation_map() ==
+ getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType()))
elidedAttrs.push_back(op.getPermutationMapAttrName());
bool elideMasked = true;
if (auto maybeMasked = op.masked()) {
@@ -1623,7 +1606,7 @@ static void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
<< "], " << op.padding();
- printTransferAttrs(p, op);
+ printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
p << " : " << op.getMemRefType() << ", " << op.getVectorType();
}
@@ -1653,8 +1636,7 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
- auto permMap =
- TransferReadOp::getTransferMinorIdentityMap(memRefType, vectorType);
+ auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
return failure(
@@ -1733,6 +1715,7 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
+
return cstOp.getValue() + vectorSize <= memrefSize;
}
@@ -1744,23 +1727,11 @@ static LogicalResult foldTransferMaskAttribute(TransferOp op) {
bool changed = false;
SmallVector<bool, 4> isMasked;
isMasked.reserve(op.getTransferRank());
- // `permutationMap` results and `op.indices` sizes may not match and may not
- // be aligned. The first `indicesIdx` may just be indexed and not transferred
- // from/into the vector.
- // For example:
- // vector.transfer %0[%i, %j, %k, %c0] : memref<?x?x?x?xf32>, vector<2x4xf32>
- // with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`.
- // The `permutationMap` results and `op.indices` are however aligned when
- // iterating in reverse until we exhaust `permutationMap` results.
- // As a consequence we iterate with 2 running indices: `resultIdx` and
- // `indicesIdx`, until `resultIdx` reaches 0.
- for (int64_t resultIdx = permutationMap.getNumResults() - 1,
- indicesIdx = op.indices().size() - 1;
- resultIdx >= 0; --resultIdx, --indicesIdx) {
+ op.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
// Already marked unmasked, nothing to see here.
if (!op.isMaskedDim(resultIdx)) {
isMasked.push_back(false);
- continue;
+ return;
}
// Currently masked, check whether we can statically determine it is
// inBounds.
@@ -1768,12 +1739,11 @@ static LogicalResult foldTransferMaskAttribute(TransferOp op) {
isMasked.push_back(!inBounds);
// We commit the pattern if it is "more inbounds".
changed |= inBounds;
- }
+ });
if (!changed)
return failure();
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
- std::reverse(isMasked.begin(), isMasked.end());
op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked));
return success();
}
@@ -1842,8 +1812,7 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
- auto permMap =
- TransferWriteOp::getTransferMinorIdentityMap(memRefType, vectorType);
+ auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
return failure(
@@ -1855,7 +1824,7 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
static void print(OpAsmPrinter &p, TransferWriteOp op) {
p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
<< op.indices() << "]";
- printTransferAttrs(p, op);
+ printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
}
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index a63862c1a4fe..ab93ef406024 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -30,7 +30,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
-#include "mlir/Interfaces/VectorUnrollInterface.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index d5beaefc5eac..75ebb2f7d959 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -243,6 +243,18 @@ AffineMap mlir::makePermutationMap(
return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}
+AffineMap mlir::getTransferMinorIdentityMap(MemRefType memRefType,
+ VectorType vectorType) {
+ int64_t elementVectorRank = 0;
+ VectorType elementVectorType =
+ memRefType.getElementType().dyn_cast<VectorType>();
+ if (elementVectorType)
+ elementVectorRank += elementVectorType.getRank();
+ return AffineMap::getMinorIdentityMap(
+ memRefType.getRank(), vectorType.getRank() - elementVectorRank,
+ memRefType.getContext());
+}
+
bool matcher::operatesOnSuperVectorsOf(Operation &op,
VectorType subVectorType) {
// First, extract the vector type and distinguish between:
@@ -257,11 +269,8 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
bool mustDivide = false;
(void)mustDivide;
VectorType superVectorType;
- if (auto read = dyn_cast<vector::TransferReadOp>(op)) {
- superVectorType = read.getVectorType();
- mustDivide = true;
- } else if (auto write = dyn_cast<vector::TransferWriteOp>(op)) {
- superVectorType = write.getVectorType();
+ if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
+ superVectorType = transfer.getVectorType();
mustDivide = true;
} else if (op.getNumResults() == 0) {
if (!isa<ReturnOp>(op)) {
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index b8498e224f25..0a8f75b6f7d9 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -6,7 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
SideEffectInterfaces.cpp
- VectorUnrollInterface.cpp
+ VectorInterfaces.cpp
ViewLikeInterface.cpp
)
@@ -33,6 +33,6 @@ add_mlir_interface_library(DerivedAttributeOpInterface)
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(LoopLikeInterface)
add_mlir_interface_library(SideEffectInterfaces)
-add_mlir_interface_library(VectorUnrollInterface)
+add_mlir_interface_library(VectorInterfaces)
add_mlir_interface_library(ViewLikeInterface)
diff --git a/mlir/lib/Interfaces/VectorUnrollInterface.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
similarity index 74%
rename from mlir/lib/Interfaces/VectorUnrollInterface.cpp
rename to mlir/lib/Interfaces/VectorInterfaces.cpp
index 6d3d432a7061..0f16b885ca2f 100644
--- a/mlir/lib/Interfaces/VectorUnrollInterface.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -1,4 +1,4 @@
-//===- VectorUnrollInterface.cpp - Unrollable vector operations -*- C++ -*-===//
+//===- VectorInterfaces.cpp - Vector interfaces -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Interfaces/VectorUnrollInterface.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
using namespace mlir;
@@ -15,4 +15,4 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
/// Include the definitions of the VectorUnroll interfaces.
-#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc"
+#include "mlir/Interfaces/VectorInterfaces.cpp.inc"
More information about the Mlir-commits
mailing list