[Mlir-commits] [mlir] 47cbd9f - [mlir][Vector] NFC - Improve VectorInterfaces

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jul 20 05:25:14 PDT 2020


Author: Nicolas Vasilache
Date: 2020-07-20T08:24:22-04:00
New Revision: 47cbd9f92282e3a19f161053cfbf77a7691de43e

URL: https://github.com/llvm/llvm-project/commit/47cbd9f92282e3a19f161053cfbf77a7691de43e
DIFF: https://github.com/llvm/llvm-project/commit/47cbd9f92282e3a19f161053cfbf77a7691de43e.diff

LOG: [mlir][Vector] NFC - Improve VectorInterfaces

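This revision improves and makes better use of OpInterfaces for the Vector dialect:
VectorUnrollInterface is renamed to VectorInterfaces, a new VectorTransferOpInterface
is implemented by vector.transfer_read and vector.transfer_write, and
getTransferMinorIdentityMap moves from VectorOps to VectorUtils.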

Differential Revision: https://reviews.llvm.org/D84053

Added: 
    mlir/include/mlir/Interfaces/VectorInterfaces.h
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Interfaces/VectorInterfaces.cpp

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/include/mlir/Interfaces/CMakeLists.txt
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/lib/Dialect/StandardOps/CMakeLists.txt
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/lib/Interfaces/CMakeLists.txt

Removed: 
    mlir/include/mlir/Interfaces/VectorUnrollInterface.h
    mlir/include/mlir/Interfaces/VectorUnrollInterface.td
    mlir/lib/Interfaces/VectorUnrollInterface.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 0f24d74dcac2..2500343c0af3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -21,7 +21,7 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/VectorUnrollInterface.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
 // Pull in all enum type definitions and utility function declarations.

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 702b912d3103..78307b897476 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -18,7 +18,7 @@ include "mlir/IR/OpAsmInterface.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/VectorUnrollInterface.td"
+include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 
 def StandardOps_Dialect : Dialect {

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 0f6aa66e926c..edf9557df389 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -19,7 +19,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/VectorUnrollInterface.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
 
 namespace mlir {
 class MLIRContext;

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 8880c288b648..10a4498b0bbd 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -15,7 +15,7 @@
 
 include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/VectorUnrollInterface.td"
+include "mlir/Interfaces/VectorInterfaces.td"
 
 def Vector_Dialect : Dialect {
   let name = "vector";
@@ -905,34 +905,9 @@ def Vector_ExtractStridedSliceOp :
   let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
 }
 
-def Vector_TransferOpUtils {
-  code extraTransferDeclaration = [{
-    static StringRef getMaskedAttrName() { return "masked"; }
-    static StringRef getPermutationMapAttrName() { return "permutation_map"; }
-    bool isMaskedDim(unsigned dim) {
-      return !masked() ||
-        masked()->cast<ArrayAttr>()[dim].cast<BoolAttr>().getValue();
-    }
-    MemRefType getMemRefType() {
-      return memref().getType().cast<MemRefType>();
-    }
-    VectorType getVectorType() {
-      return vector().getType().cast<VectorType>();
-    }
-    // Number of dimensions that participate in the permutation map.
-    unsigned getTransferRank() {
-      return permutation_map().getNumResults();
-    }
-    // Number of leading dimensions that do not participate in the permutation
-    // map.
-    unsigned getLeadingMemRefRank() {
-      return getMemRefType().getRank() - permutation_map().getNumResults();
-    }
-  }];
-}
-
 def Vector_TransferReadOp :
   Vector_Op<"transfer_read", [
+      DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
     Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
@@ -1090,23 +1065,12 @@ def Vector_TransferReadOp :
               "ArrayRef<bool> maybeMasked = {}">
   ];
 
-  let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #
-  [{
-    /// Build the default minor identity map suitable for a vector transfer.
-    /// This also handles the case memref<... x vector<...>> -> vector<...> in
-    /// which the rank of the identity map must take the vector element type
-    /// into account.
-    static AffineMap getTransferMinorIdentityMap(
-      MemRefType memRefType, VectorType vectorType) {
-        return impl::getTransferMinorIdentityMap(memRefType, vectorType);
-      }
-  }];
-
   let hasFolder = 1;
 }
 
 def Vector_TransferWriteOp :
   Vector_Op<"transfer_write", [
+      DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
     Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
@@ -1183,18 +1147,6 @@ def Vector_TransferWriteOp :
               "Value memref, ValueRange indices, AffineMap permutationMap">,
   ];
 
-  let extraClassDeclaration = Vector_TransferOpUtils.extraTransferDeclaration #
-  [{
-    /// Build the default minor identity map suitable for a vector transfer.
-    /// This also handles the case memref<... x vector<...>> -> vector<...> in
-    /// which the rank of the identity map must take the vector element type
-    /// into account.
-    static AffineMap getTransferMinorIdentityMap(
-      MemRefType memRefType, VectorType vectorType) {
-        return impl::getTransferMinorIdentityMap(memRefType, vectorType);
-      }
-  }];
-
   let hasFolder = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 19f7f9538307..448004db32fa 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -153,6 +153,12 @@ AffineMap
 makePermutationMap(Operation *op, ArrayRef<Value> indices,
                    const DenseMap<Operation *, unsigned> &loopToVectorDim);
 
+/// Build the default minor identity map suitable for a vector transfer. This
+/// also handles the case memref<... x vector<...>> -> vector<...> in which the
+/// rank of the identity map must take the vector element type into account.
+AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+                                      VectorType vectorType);
+
 namespace matcher {
 
 /// Matches vector.transfer_read, vector.transfer_write and ops that return a

diff  --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 0de2b5a8688b..65e19f3eec1b 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,6 @@ add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(SideEffectInterfaces)
-add_mlir_interface(VectorUnrollInterface)
+add_mlir_interface(VectorInterfaces)
 add_mlir_interface(ViewLikeInterface)
 

diff  --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h b/mlir/include/mlir/Interfaces/VectorInterfaces.h
similarity index 59%
rename from mlir/include/mlir/Interfaces/VectorUnrollInterface.h
rename to mlir/include/mlir/Interfaces/VectorInterfaces.h
index a68cc3411533..2134969e4020 100644
--- a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.h
@@ -1,4 +1,4 @@
-//===- VectorUnrollInterface.h - Vector unrolling interface ---------------===//
+//===- VectorInterfaces.h - Vector interfaces -----------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,18 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements the operation interface for vector ops that can be
-// unrolled.
+// This file implements the operation interfaces for vector ops.
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
-#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
+#ifndef MLIR_INTERFACES_VECTORINTERFACES_H
+#define MLIR_INTERFACES_VECTORINTERFACES_H
 
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 
 /// Include the generated interface declarations.
-#include "mlir/Interfaces/VectorUnrollInterface.h.inc"
+#include "mlir/Interfaces/VectorInterfaces.h.inc"
 
-#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
+#endif // MLIR_INTERFACES_VECTORINTERFACES_H

diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
new file mode 100644
index 000000000000..aefbb7d47117
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -0,0 +1,194 @@
+//===- VectorInterfaces.td - Vector interfaces -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for operations on vectors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_VECTORINTERFACES
+#define MLIR_INTERFACES_VECTORINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
+  let description = [{
+    Encodes properties of an operation on vectors that can be unrolled.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the vector shape to consider when unrolling this op. The default
+        implementation returns the shape of the single vector result. Return
+        `None` if the op cannot be unrolled (e.g. its result is not a vector).
+      }],
+      /*retTy=*/"Optional<SmallVector<int64_t, 4>>",
+      /*methodName=*/"getShapeForUnroll",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert($_op.getOperation()->getNumResults() == 1);
+        auto vt = $_op.getResult().getType().
+          template dyn_cast<VectorType>();
+        if (!vt)
+          return None;
+        SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end());
+        return res;
+      }]
+    >,
+  ];
+}
+
+def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
+  let description = [{
+    Encodes properties of a transfer operation between a memref and a vector.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    StaticInterfaceMethod<
+      /*desc=*/"Return the `masked` attribute name.",
+      /*retTy=*/"StringRef",
+      /*methodName=*/"getMaskedAttrName",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/ [{ return "masked"; }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/"Return the `permutation_map` attribute name.",
+      /*retTy=*/"StringRef",
+      /*methodName=*/"getPermutationMapAttrName",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/ [{ return "permutation_map"; }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+      Return `false` when the `masked` attribute at dimension
+      `dim` is set to `false`. Return `true` otherwise.}],
+      /*retTy=*/"bool",
+      /*methodName=*/"isMaskedDim",
+      /*args=*/(ins "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return !$_op.masked() ||
+          $_op.masked()->template cast<ArrayAttr>()[dim]
+                        .template cast<BoolAttr>().getValue();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the memref operand.",
+      /*retTy=*/"Value",
+      /*methodName=*/"memref",
+      /*args=*/(ins),
+      /*methodBody=*/"return $_op.memref();"
+      /*defaultImplementation=*/
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the vector operand or result.",
+      /*retTy=*/"Value",
+      /*methodName=*/"vector",
+      /*args=*/(ins),
+      /*methodBody=*/"return $_op.vector();"
+      /*defaultImplementation=*/
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the indices operands.",
+      /*retTy=*/"ValueRange",
+      /*methodName=*/"indices",
+      /*args=*/(ins),
+      /*methodBody=*/"return $_op.indices();"
+      /*defaultImplementation=*/
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the permutation map.",
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"permutation_map",
+      /*args=*/(ins),
+      /*methodBody=*/"return $_op.permutation_map();"
+      /*defaultImplementation=*/
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the `masked` boolean ArrayAttr.",
+      /*retTy=*/"Optional<ArrayAttr>",
+      /*methodName=*/"masked",
+      /*args=*/(ins),
+      /*methodBody=*/"return $_op.masked();"
+      /*defaultImplementation=*/
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the MemRefType.",
+      /*retTy=*/"MemRefType",
+      /*methodName=*/"getMemRefType",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/
+        "return $_op.memref().getType().template cast<MemRefType>();"
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the VectorType.",
+      /*retTy=*/"VectorType",
+      /*methodName=*/"getVectorType",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/
+        "return $_op.vector().getType().template cast<VectorType>();"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Return the number of dimensions that participate in the
+                  permutation map.}],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getTransferRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/
+        "return $_op.permutation_map().getNumResults();"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Return the number of leading memref dimensions that do not
+                  participate in the permutation map.}],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getLeadingMemRefRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/
+        "return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+      Helper function to account for the fact that `permutationMap` results and
+      `op.indices` sizes may not match and may not be aligned. The first
+      `getLeadingMemRefRank()` indices may just be indexed and not transferred
+      from/into the vector.
+      For example:
+      ```
+         vector.transfer %0[%i, %j, %k, %c0] :
+           memref<?x?x?x?xf32>, vector<2x4xf32>
+      ```
+      with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`.
+      Provide a zip function to coiterate on 2 running indices: `resultIdx` and
+      `indicesIdx` which accounts for this misalignment.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"zipResultAndIndexing",
+      /*args=*/(ins "llvm::function_ref<void(int64_t, int64_t)>":$fun),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        for (int64_t resultIdx = 0,
+                     indicesIdx = $_op.getLeadingMemRefRank(),
+                     eResult = $_op.getTransferRank();
+             resultIdx < eResult;
+             ++resultIdx, ++indicesIdx)
+          fun(resultIdx, indicesIdx);
+      }]
+    >,
+  ];
+}
+
+#endif // MLIR_INTERFACES_VECTORINTERFACES

diff  --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
deleted file mode 100644
index 166780b20e77..000000000000
--- a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
+++ /dev/null
@@ -1,46 +0,0 @@
-//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines the interface for operations on vectors that can be unrolled.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE
-#define MLIR_INTERFACES_VECTORUNROLLINTERFACE
-
-include "mlir/IR/OpBase.td"
-
-def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
-  let description = [{
-    Encodes properties of an operation on vectors that can be unrolled.
-  }];
-  let cppNamespace = "::mlir";
-
-  let methods = [
-    InterfaceMethod<[{
-        Returns the shape ratio of unrolling to the target vector shape
-        `targetShape`. Returns `None` if the op cannot be unrolled to the target
-        vector shape.
-      }],
-      "Optional<SmallVector<int64_t, 4>>",
-      "getShapeForUnroll",
-      (ins),
-      /*methodBody=*/[{}],
-      [{
-        auto vt = this->getOperation()->getResult(0).getType().
-          template dyn_cast<VectorType>();
-        if (!vt)
-          return None;
-        SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end());
-        return res;
-      }]
-    >,
-  ];
-}
-
-#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index d0529668b2ee..ea368c9eb14e 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -249,8 +249,8 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
       indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
       indexing.append(minorOffsets.begin(), minorOffsets.end());
       Value memref = xferOp.memref();
-      auto map = TransferReadOp::getTransferMinorIdentityMap(
-          xferOp.getMemRefType(), minorVectorType);
+      auto map =
+          getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
       ArrayAttr masked;
       if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
         OpBuilder &b = ScopedContext::getBuilderRef();
@@ -353,8 +353,8 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
         result = vector_extract(xferOp.vector(), majorIvs);
       else
         result = std_load(alloc, majorIvs);
-      auto map = TransferWriteOp::getTransferMinorIdentityMap(
-          xferOp.getMemRefType(), minorVectorType);
+      auto map =
+          getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
       ArrayAttr masked;
       if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
         OpBuilder &b = ScopedContext::getBuilderRef();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index cb7540b46cf8..180fe069b681 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -82,8 +82,8 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
 
 /// Return true if we can prove that the transfer operations access dijoint
 /// memory.
-template <typename TransferTypeA, typename TransferTypeB>
-static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) {
+static bool isDisjoint(VectorTransferOpInterface transferA,
+                       VectorTransferOpInterface transferB) {
   if (transferA.memref() != transferB.memref())
     return false;
   // For simplicity only look at transfer of same type.
@@ -91,8 +91,8 @@ static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) {
     return false;
   unsigned rankOffset = transferA.getLeadingMemRefRank();
   for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
-    auto indexA = transferA.indices()[i].template getDefiningOp<ConstantOp>();
-    auto indexB = transferB.indices()[i].template getDefiningOp<ConstantOp>();
+    auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
+    auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();
     // If any of the indices are dynamic we cannot prove anything.
     if (!indexA || !indexB)
       continue;
@@ -100,15 +100,15 @@ static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) {
     if (i < rankOffset) {
       // For dimension used as index if we can prove that index are different we
       // know we are accessing disjoint slices.
-      if (indexA.getValue().template cast<IntegerAttr>().getInt() !=
-          indexB.getValue().template cast<IntegerAttr>().getInt())
+      if (indexA.getValue().cast<IntegerAttr>().getInt() !=
+          indexB.getValue().cast<IntegerAttr>().getInt())
         return true;
     } else {
       // For this dimension, we slice a part of the memref we need to make sure
       // the intervals accessed don't overlap.
       int64_t distance =
-          std::abs(indexA.getValue().template cast<IntegerAttr>().getInt() -
-                   indexB.getValue().template cast<IntegerAttr>().getInt());
+          std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
+                   indexB.getValue().cast<IntegerAttr>().getInt());
       if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
         return true;
     }
@@ -185,11 +185,17 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
           continue;
         if (auto transferWriteUse =
                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
-          if (!isDisjoint(transferWrite, transferWriteUse))
+          if (!isDisjoint(
+                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
+                  cast<VectorTransferOpInterface>(
+                      transferWriteUse.getOperation())))
             return WalkResult::advance();
         } else if (auto transferReadUse =
                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
-          if (!isDisjoint(transferWrite, transferReadUse))
+          if (!isDisjoint(
+                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
+                  cast<VectorTransferOpInterface>(
+                      transferReadUse.getOperation())))
             return WalkResult::advance();
         } else {
           // Unknown use, we cannot prove that it doesn't alias with the

diff  --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index 7d61aea3116e..06284f5d1daa 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -15,7 +15,7 @@ add_mlir_dialect_library(MLIRStandardOps
   MLIREDSC
   MLIRIR
   MLIRSideEffectInterfaces
-  MLIRVectorUnrollInterface
+  MLIRVectorInterfaces
   MLIRViewLikeInterface
   )
 

diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 69a329917228..d6ba987e6622 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -19,5 +19,5 @@ add_mlir_dialect_library(MLIRVector
   MLIRSCF
   MLIRLoopAnalysis
   MLIRSideEffectInterfaces
-  MLIRVectorUnrollInterface
+  MLIRVectorInterfaces
   )

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5e01fa26f32e..03c4079ef171 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1466,22 +1466,6 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
 // TransferReadOp
 //===----------------------------------------------------------------------===//
 
-/// Build the default minor identity map suitable for a vector transfer. This
-/// also handles the case memref<... x vector<...>> -> vector<...> in which the
-/// rank of the identity map must take the vector element type into account.
-AffineMap
-mlir::vector::impl::getTransferMinorIdentityMap(MemRefType memRefType,
-                                                VectorType vectorType) {
-  int64_t elementVectorRank = 0;
-  VectorType elementVectorType =
-      memRefType.getElementType().dyn_cast<VectorType>();
-  if (elementVectorType)
-    elementVectorRank += elementVectorType.getRank();
-  return AffineMap::getMinorIdentityMap(
-      memRefType.getRank(), vectorType.getRank() - elementVectorRank,
-      memRefType.getContext());
-}
-
 template <typename EmitFun>
 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {
@@ -1600,11 +1584,10 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
 }
 
-template <typename TransferOp>
-static void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
+static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   SmallVector<StringRef, 2> elidedAttrs;
-  if (op.permutation_map() == TransferOp::getTransferMinorIdentityMap(
-                                  op.getMemRefType(), op.getVectorType()))
+  if (op.permutation_map() ==
+      getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType()))
     elidedAttrs.push_back(op.getPermutationMapAttrName());
   bool elideMasked = true;
   if (auto maybeMasked = op.masked()) {
@@ -1623,7 +1606,7 @@ static void printTransferAttrs(OpAsmPrinter &p, TransferOp op) {
 static void print(OpAsmPrinter &p, TransferReadOp op) {
   p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
     << "], " << op.padding();
-  printTransferAttrs(p, op);
+  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
   p << " : " << op.getMemRefType() << ", " << op.getVectorType();
 }
 
@@ -1653,8 +1636,7 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
   auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
   auto attr = result.attributes.get(permutationAttrName);
   if (!attr) {
-    auto permMap =
-        TransferReadOp::getTransferMinorIdentityMap(memRefType, vectorType);
+    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
   return failure(
@@ -1733,6 +1715,7 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
 
   int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
   int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
+
   return cstOp.getValue() + vectorSize <= memrefSize;
 }
 
@@ -1744,23 +1727,11 @@ static LogicalResult foldTransferMaskAttribute(TransferOp op) {
   bool changed = false;
   SmallVector<bool, 4> isMasked;
   isMasked.reserve(op.getTransferRank());
-  // `permutationMap` results and `op.indices` sizes may not match and may not
-  // be aligned. The first `indicesIdx` may just be indexed and not transferred
-  // from/into the vector.
-  // For example:
-  //  vector.transfer %0[%i, %j, %k, %c0] : memref<?x?x?x?xf32>, vector<2x4xf32>
-  // with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`.
-  // The `permutationMap` results and `op.indices` are however aligned when
-  // iterating in reverse until we exhaust `permutationMap` results.
-  // As a consequence we iterate with 2 running indices: `resultIdx` and
-  // `indicesIdx`, until `resultIdx` reaches 0.
-  for (int64_t resultIdx = permutationMap.getNumResults() - 1,
-               indicesIdx = op.indices().size() - 1;
-       resultIdx >= 0; --resultIdx, --indicesIdx) {
+  op.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
     // Already marked unmasked, nothing to see here.
     if (!op.isMaskedDim(resultIdx)) {
       isMasked.push_back(false);
-      continue;
+      return;
     }
     // Currently masked, check whether we can statically determine it is
     // inBounds.
@@ -1768,12 +1739,11 @@ static LogicalResult foldTransferMaskAttribute(TransferOp op) {
     isMasked.push_back(!inBounds);
     // We commit the pattern if it is "more inbounds".
     changed |= inBounds;
-  }
+  });
   if (!changed)
     return failure();
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(op.getContext());
-  std::reverse(isMasked.begin(), isMasked.end());
   op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked));
   return success();
 }
@@ -1842,8 +1812,7 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
   auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
   auto attr = result.attributes.get(permutationAttrName);
   if (!attr) {
-    auto permMap =
-        TransferWriteOp::getTransferMinorIdentityMap(memRefType, vectorType);
+    auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
   return failure(
@@ -1855,7 +1824,7 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
 static void print(OpAsmPrinter &p, TransferWriteOp op) {
   p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
     << op.indices() << "]";
-  printTransferAttrs(p, op);
+  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
   p << " : " << op.getVectorType() << ", " << op.getMemRefType();
 }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index a63862c1a4fe..ab93ef406024 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -30,7 +30,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
-#include "mlir/Interfaces/VectorUnrollInterface.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
 
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index d5beaefc5eac..75ebb2f7d959 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -243,6 +243,18 @@ AffineMap mlir::makePermutationMap(
   return ::makePermutationMap(indices, enclosingLoopToVectorDim);
 }
 
+AffineMap mlir::getTransferMinorIdentityMap(MemRefType memRefType,
+                                            VectorType vectorType) {
+  // A vector element type (memref<... x vector<...>>) absorbs its own rank.
+  auto elementVectorType = memRefType.getElementType().dyn_cast<VectorType>();
+  int64_t elementVectorRank =
+      elementVectorType ? elementVectorType.getRank() : 0;
+  int64_t numResults = vectorType.getRank() - elementVectorRank;
+  auto *ctx = memRefType.getContext();
+  return AffineMap::getMinorIdentityMap(memRefType.getRank(), numResults,
+                                        ctx);
+}
+
 bool matcher::operatesOnSuperVectorsOf(Operation &op,
                                        VectorType subVectorType) {
   // First, extract the vector type and distinguish between:
@@ -257,11 +269,8 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
   bool mustDivide = false;
   (void)mustDivide;
   VectorType superVectorType;
-  if (auto read = dyn_cast<vector::TransferReadOp>(op)) {
-    superVectorType = read.getVectorType();
-    mustDivide = true;
-  } else if (auto write = dyn_cast<vector::TransferWriteOp>(op)) {
-    superVectorType = write.getVectorType();
+  if (auto transfer = dyn_cast<VectorTransferOpInterface>(op)) {
+    superVectorType = transfer.getVectorType();
     mustDivide = true;
   } else if (op.getNumResults() == 0) {
     if (!isa<ReturnOp>(op)) {

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index b8498e224f25..0a8f75b6f7d9 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -6,7 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   SideEffectInterfaces.cpp
-  VectorUnrollInterface.cpp
+  VectorInterfaces.cpp
   ViewLikeInterface.cpp
   )
 
@@ -33,6 +33,6 @@ add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(LoopLikeInterface)
 add_mlir_interface_library(SideEffectInterfaces)
-add_mlir_interface_library(VectorUnrollInterface)
+add_mlir_interface_library(VectorInterfaces)
 add_mlir_interface_library(ViewLikeInterface)
 

diff  --git a/mlir/lib/Interfaces/VectorUnrollInterface.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
similarity index 74%
rename from mlir/lib/Interfaces/VectorUnrollInterface.cpp
rename to mlir/lib/Interfaces/VectorInterfaces.cpp
index 6d3d432a7061..0f16b885ca2f 100644
--- a/mlir/lib/Interfaces/VectorUnrollInterface.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -1,4 +1,4 @@
-//===- VectorUnrollInterface.cpp - Unrollable vector operations -*- C++ -*-===//
+//===- VectorInterfaces.cpp - Vector interfaces ---------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Interfaces/VectorUnrollInterface.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
 
 using namespace mlir;
 
@@ -15,4 +15,4 @@ using namespace mlir;
 //===----------------------------------------------------------------------===//
 
 /// Include the definitions of the VectorUntoll interfaces.
-#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc"
+#include "mlir/Interfaces/VectorInterfaces.cpp.inc"


        


More information about the Mlir-commits mailing list