[Mlir-commits] [mlir] d3588d0 - [mlir][NFC] Replace mlir/Support/Functional.h with llvm equivalents.

River Riddle llvmlistbot at llvm.org
Mon Apr 13 14:23:03 PDT 2020


Author: River Riddle
Date: 2020-04-13T14:22:12-07:00
New Revision: d3588d0814c4cbc7fca677b4d9634f6e1428a331

URL: https://github.com/llvm/llvm-project/commit/d3588d0814c4cbc7fca677b4d9634f6e1428a331
DIFF: https://github.com/llvm/llvm-project/commit/d3588d0814c4cbc7fca677b4d9634f6e1428a331.diff

LOG: [mlir][NFC] Replace mlir/Support/Functional.h with llvm equivalents.

Summary: Functional.h contains many different methods that have direct, and more efficient, equivalents in LLVM. This revision replaces all usages with the LLVM equivalents and removes the header. This is part of a larger cleanup, pr45513, merging MLIR support facilities into LLVM.

Differential Revision: https://reviews.llvm.org/D78053

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Analysis/SliceAnalysis.cpp
    mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
    mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
    mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
    mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/Vector/EDSC/Builders.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/lib/IR/Builders.cpp
    mlir/test/EDSC/builder-api-test.cpp
    mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
    mlir/test/mlir-tblgen/predicate.td

Removed: 
    mlir/include/mlir/Support/Functional.h


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 56970f7a5aa7..56c958d16a8f 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2064,9 +2064,10 @@ class TCOpIsBroadcastableToRes<int opId, int resId> : And<[
 // 1) all operands involved are of shaped type and
 // 2) the indices are not out of range.
 class TCopVTEtAreSameAt<list<int> indices> : CPred<
-  "llvm::is_splat(mlir::functional::map("
-    "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }, "
-    "llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "})))">;
+  "llvm::is_splat(llvm::map_range("
+    "llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "}), "
+    "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); "
+    "}))">;
 
 //===----------------------------------------------------------------------===//
 // Pattern definitions

diff  --git a/mlir/include/mlir/Support/Functional.h b/mlir/include/mlir/Support/Functional.h
deleted file mode 100644
index 0950ab9372e7..000000000000
--- a/mlir/include/mlir/Support/Functional.h
+++ /dev/null
@@ -1,113 +0,0 @@
-//===- Functional.h - Helpers for functional-style Combinators --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_SUPPORT_FUNCTIONAL_H_
-#define MLIR_SUPPORT_FUNCTIONAL_H_
-
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
-
-/// This file provides some simple template functional-style sugar to operate
-/// on **value** types. Make sure when using that the stored type is cheap to
-/// copy!
-///
-/// TODO(ntv): add some static_assert but we need proper traits for this.
-
-namespace mlir {
-namespace functional {
-
-/// Map with iterators.
-template <typename Fn, typename IterType>
-auto map(Fn fun, IterType begin, IterType end)
-    -> SmallVector<typename std::result_of<Fn(decltype(*begin))>::type, 8> {
-  using R = typename std::result_of<Fn(decltype(*begin))>::type;
-  SmallVector<R, 8> res;
-  // auto i works with both pointer types and value types with an operator*.
-  // auto *i only works for pointer types.
-  for (auto i = begin; i != end; ++i) {
-    res.push_back(fun(*i));
-  }
-  return res;
-}
-
-/// Map with templated container.
-template <typename Fn, typename ContainerType>
-auto map(Fn fun, ContainerType input)
-    -> decltype(map(fun, std::begin(input), std::end(input))) {
-  return map(fun, std::begin(input), std::end(input));
-}
-
-/// Zip map with 2 templated container, iterates to the min of the sizes of
-/// the 2 containers.
-/// TODO(ntv): make variadic when needed.
-template <typename Fn, typename ContainerType1, typename ContainerType2>
-auto zipMap(Fn fun, ContainerType1 input1, ContainerType2 input2)
-    -> SmallVector<typename std::result_of<Fn(decltype(*input1.begin()),
-                                              decltype(*input2.begin()))>::type,
-                   8> {
-  using R = typename std::result_of<Fn(decltype(*input1.begin()),
-                                       decltype(*input2.begin()))>::type;
-  SmallVector<R, 8> res;
-  auto zipIter = llvm::zip(input1, input2);
-  for (auto it : zipIter) {
-    res.push_back(fun(std::get<0>(it), std::get<1>(it)));
-  }
-  return res;
-}
-
-/// Apply with iterators.
-template <typename Fn, typename IterType>
-void apply(Fn fun, IterType begin, IterType end) {
-  // auto i works with both pointer types and value types with an operator*.
-  // auto *i only works for pointer types.
-  for (auto i = begin; i != end; ++i) {
-    fun(*i);
-  }
-}
-
-/// Apply with templated container.
-template <typename Fn, typename ContainerType>
-void apply(Fn fun, ContainerType input) {
-  return apply(fun, std::begin(input), std::end(input));
-}
-
-/// Zip apply with 2 templated container, iterates to the min of the sizes of
-/// the 2 containers.
-/// TODO(ntv): make variadic when needed.
-template <typename Fn, typename ContainerType1, typename ContainerType2>
-void zipApply(Fn fun, ContainerType1 input1, ContainerType2 input2) {
-  auto zipIter = llvm::zip(input1, input2);
-  for (auto it : zipIter) {
-    fun(std::get<0>(it), std::get<1>(it));
-  }
-}
-
-/// Unwraps a pointer type to another type (possibly the same).
-/// Used in particular to allow easier compositions of
-///   Operation::operand_range types.
-template <typename T, typename ToType = T>
-inline std::function<ToType *(T *)> makePtrDynCaster() {
-  return [](T *val) { return dyn_cast<ToType>(val); };
-}
-
-/// Simple ScopeGuard.
-struct ScopeGuard {
-  explicit ScopeGuard(std::function<void(void)> destruct)
-      : destruct(destruct) {}
-  ~ScopeGuard() { destruct(); }
-
-private:
-  std::function<void(void)> destruct;
-};
-
-} // namespace functional
-} // namespace mlir
-
-#endif // MLIR_SUPPORT_FUNCTIONAL_H_

diff  --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index b1e45d1cfe7b..56fabcf28df9 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/SetVector.h"

diff  --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 54520b28c158..2fb99461647e 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -23,7 +23,6 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
 
@@ -222,13 +221,13 @@ Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder,
                                                       AffineMap affineMap,
                                                       ValueRange operands) {
   auto numDims = affineMap.getNumDims();
-  auto expanded = functional::map(
-      [numDims, &builder, loc, operands](AffineExpr expr) {
-        return expandAffineExpr(builder, loc, expr,
-                                operands.take_front(numDims),
-                                operands.drop_front(numDims));
-      },
-      affineMap.getResults());
+  auto expanded = llvm::to_vector<8>(
+      llvm::map_range(affineMap.getResults(),
+                      [numDims, &builder, loc, operands](AffineExpr expr) {
+                        return expandAffineExpr(builder, loc, expr,
+                                                operands.take_front(numDims),
+                                                operands.drop_front(numDims));
+                      }));
   if (llvm::all_of(expanded, [](Value v) { return v; }))
     return expanded;
   return None;

diff  --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
index bb00639ed964..fbcad36c7bd5 100644
--- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
+++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp
@@ -20,7 +20,6 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 2f3fdd01059f..acc84f9e9a46 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -23,7 +23,6 @@
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 620b200333ee..90f098662086 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -25,7 +25,6 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Types.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
 
@@ -525,8 +524,6 @@ using namespace mlir;
 
 #define DEBUG_TYPE "early-vect"
 
-using functional::makePtrDynCaster;
-using functional::map;
 using llvm::dbgs;
 using llvm::SetVector;
 
@@ -812,7 +809,6 @@ static LogicalResult vectorizeRootOrTerminal(Value iv,
 /// operations into the appropriate vector.transfer.
 static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step,
                                           VectorizationState *state) {
-  using namespace functional;
   loop.setStep(step);
 
   FilterFunctionType notVectorizedThisPattern = [state](Operation &op) {

diff  --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
index b24a925aaad3..24dd018e0c44 100644
--- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/LoopOps/EDSC/Builders.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Transforms/LoopUtils.h"
 
 using namespace mlir;

diff  --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 10c18107fd8e..59a565c5e395 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/Support/Functional.h"
 
 using namespace mlir;
 using namespace mlir::edsc;
@@ -164,7 +163,8 @@ Operation *mlir::edsc::makeGenericLinalgOp(
   std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(types),
                [](StructuredIndexed s) { return !s.hasValue(); });
 
-  auto iteratorStrTypes = functional::map(toString, iteratorTypes);
+  auto iteratorStrTypes =
+      llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString));
   // clang-format off
   auto *op =
       edsc::ScopedContext::getBuilder()

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3d81cce0e883..24decc760525 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -21,7 +21,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
 
@@ -500,8 +499,8 @@ computeReshapeCollapsedType(MemRefType type,
 /// TODO(rridle,ntv) this should be evolved into a generic
 /// `getRangeOfType<AffineMap>(ArrayAttr attrs)` that does not copy.
 static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
-  return functional::map(
-      [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
+  return llvm::to_vector<8>(llvm::map_range(
+      attrs, [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
 }
 
 template <typename AffineExprTy>

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 6be0bd8ea204..7c4389341349 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -19,7 +19,6 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -113,8 +112,8 @@ getInputAndOutputIndices(ArrayRef<Value> allIvs, SingleInputPoolingOp op) {
   auto &b = ScopedContext::getBuilder();
   auto loc = ScopedContext::getLocation();
   auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
-  auto maps =
-      functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
+  auto maps = llvm::to_vector<8>(
+      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
   SmallVector<ValueHandle, 8> iIdx(
       makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
   SmallVector<ValueHandle, 8> oIdx(
@@ -273,8 +272,8 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
-    auto maps = functional::map([](AffineMapAttr a) { return a.getValue(); },
-                                mapsRange);
+    auto maps = llvm::to_vector<8>(llvm::map_range(
+        mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
     SmallVector<ValueHandle, 8> fIdx(
         makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
     SmallVector<ValueHandle, 8> imIdx(
@@ -650,8 +649,8 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
   auto nLoops = nPar + nRed + nWin;
   auto mapsRange =
       linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
-  auto maps =
-      functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
+  auto maps = llvm::to_vector<8>(
+      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
   AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
   if (invertedMap.isEmpty()) {
     LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 997895b6a869..9026ccfd4af9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -21,7 +21,6 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -359,8 +358,8 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (asserted in the inverse calculation).
   auto mapsRange = op.indexing_maps().getAsRange<AffineMapAttr>();
-  auto maps =
-      functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
+  auto maps = llvm::to_vector<8>(
+      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
   auto viewSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
   assert(viewSizesToLoopsMap && "expected invertible map");
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index 953d95b449d1..efd15d773f8f 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/Functional.h"
 
 using namespace mlir;
 
@@ -131,11 +130,10 @@ void spirv::BitcastOp::getCanonicalizationPatterns(
 
 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
-  auto indexVector = functional::map(
-      [](Attribute attr) {
+  auto indexVector =
+      llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
         return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
-      },
-      indices());
+      }));
   return extractCompositeElement(operands[0], indexVector);
 }
 

diff  --git a/mlir/lib/Dialect/Vector/EDSC/Builders.cpp b/mlir/lib/Dialect/Vector/EDSC/Builders.cpp
index d2436ef8ae62..0759f93edaec 100644
--- a/mlir/lib/Dialect/Vector/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Vector/EDSC/Builders.cpp
@@ -13,7 +13,6 @@
 #include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/Support/Functional.h"
 
 using namespace mlir;
 using namespace mlir::edsc;
@@ -27,7 +26,8 @@ Value mlir::edsc::ops::vector_contraction(
   return vector_contract(
       A.getValue(), B.getValue(), C.getValue(),
       IndexingExprs{A.getExprs(), B.getExprs(), C.getExprs()},
-      ArrayRef<StringRef>{functional::map(toString, iteratorTypes)});
+      ArrayRef<StringRef>{
+          llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString))});
 }
 
 Value mlir::edsc::ops::vector_contraction_matmul(Value A, Value B, Value C) {

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 53eee4a6bc62..150370f1ad3e 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -21,7 +21,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
@@ -893,12 +892,10 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
 
 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                   MLIRContext *context) {
-  auto attrs = functional::map(
-      [context](int64_t v) -> Attribute {
-        return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
-      },
-      values);
-  return ArrayAttr::get(attrs, context);
+  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
+    return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
+  });
+  return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
 }
 
 static LogicalResult verify(InsertStridedSliceOp op) {

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 3cd0b7b4b733..016897e7f891 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -29,7 +29,6 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/STLExtras.h"
 
 #include "llvm/Support/CommandLine.h"
@@ -40,7 +39,6 @@
 
 using namespace mlir;
 using llvm::dbgs;
-using mlir::functional::zipMap;
 
 /// Given a shape with sizes greater than 0 along all dimensions,
 /// returns the distance, in number of elements, between a slice in a dimension
@@ -774,19 +772,15 @@ static Value getProducerValue(Value consumerValue) {
       int i = sourceVectorRank - 1;
       int j = resultVectorRank - 1;
 
-      // Check that source/result vector shape prefixes match while
-      // updating 'newOffsets'.
-      bool canShapeCastFold = true;
+      // Check that source/result vector shape prefixes match while updating
+      // 'newOffsets'.
       SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
-
-      auto apply = [&](int64_t sourceSize, int64_t resultSize) {
-        canShapeCastFold = sourceSize == resultSize;
+      for (auto it : llvm::zip(llvm::reverse(sourceVectorShape),
+                               llvm::reverse(resultVectorShape))) {
+        if (std::get<0>(it) != std::get<1>(it))
+          return nullptr;
         newOffsets[i--] = offsets[j--];
-      };
-      functional::zipApply(apply, llvm::reverse(sourceVectorShape),
-                           llvm::reverse(resultVectorShape));
-      if (!canShapeCastFold)
-        return nullptr;
+      }
 
       // Check that remaining prefix of source/result vector shapes are all 1s.
       // Currently we only support producer/consumer tracking through trivial

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 9038b7ad6617..2c47da741ea7 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -18,7 +18,6 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
@@ -67,8 +66,10 @@ SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
 
 SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
     ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
-  return functional::zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
-                            vectorOffsets, sizes);
+  SmallVector<int64_t, 4> result;
+  for (auto it : llvm::zip(vectorOffsets, sizes))
+    result.push_back(std::get<0>(it) * std::get<1>(it));
+  return result;
 }
 
 SmallVector<int64_t, 4>
@@ -88,23 +89,19 @@ Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
   }
 
   // Starting from the end, compute the integer divisors.
-  // Set the boolean `divides` if integral division is not possible.
   std::vector<int64_t> result;
   result.reserve(superShape.size());
-  bool divides = true;
-  auto divide = [&divides, &result](int superSize, int subSize) {
+  int64_t superSize = 0, subSize = 0;
+  for (auto it :
+       llvm::zip(llvm::reverse(superShape), llvm::reverse(subShape))) {
+    std::tie(superSize, subSize) = it;
     assert(superSize > 0 && "superSize must be > 0");
     assert(subSize > 0 && "subSize must be > 0");
-    divides &= (superSize % subSize == 0);
-    result.push_back(superSize / subSize);
-  };
-  functional::zipApply(
-      divide, SmallVector<int64_t, 8>{superShape.rbegin(), superShape.rend()},
-      SmallVector<int64_t, 8>{subShape.rbegin(), subShape.rend()});
 
-  // If integral division does not occur, return and let the caller decide.
-  if (!divides) {
-    return None;
+    // If integral division does not occur, return and let the caller decide.
+    if (superSize % subSize != 0)
+      return None;
+    result.push_back(superSize / subSize);
   }
 
   // At this point we computed the ratio (in reverse) for the common
@@ -157,8 +154,6 @@ static AffineMap makePermutationMap(
     return AffineMap();
   MLIRContext *context =
       enclosingLoopToVectorDim.begin()->getFirst()->getContext();
-  using functional::makePtrDynCaster;
-  using functional::map;
   SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
                                   getAffineConstantExpr(0, context));
 

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 7e5036e37bab..be62a164cc8d 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -10,7 +10,6 @@
 #include "AffineMapDetail.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/StringRef.h"

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 84a883c64d24..40954a69f58f 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -14,7 +14,6 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/Functional.h"
 #include "llvm/Support/raw_ostream.h"
 using namespace mlir;
 
@@ -204,47 +203,46 @@ Builder::getSymbolRefAttr(StringRef value,
 }
 
 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
-  auto attrs = functional::map(
-      [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values);
+  auto attrs = llvm::to_vector<8>(llvm::map_range(
+      values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
-  auto attrs = functional::map(
-      [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }, values);
+  auto attrs = llvm::to_vector<8>(llvm::map_range(
+      values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
-  auto attrs = functional::map(
-      [this](int64_t v) -> Attribute {
+  auto attrs = llvm::to_vector<8>(
+      llvm::map_range(values, [this](int64_t v) -> Attribute {
         return getIntegerAttr(IndexType::get(getContext()), v);
-      },
-      values);
+      }));
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
-  auto attrs = functional::map(
-      [this](float v) -> Attribute { return getF32FloatAttr(v); }, values);
+  auto attrs = llvm::to_vector<8>(llvm::map_range(
+      values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
-  auto attrs = functional::map(
-      [this](double v) -> Attribute { return getF64FloatAttr(v); }, values);
+  auto attrs = llvm::to_vector<8>(llvm::map_range(
+      values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
-  auto attrs = functional::map(
-      [this](StringRef v) -> Attribute { return getStringAttr(v); }, values);
+  auto attrs = llvm::to_vector<8>(llvm::map_range(
+      values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
-  auto attrs = functional::map(
-      [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }, values);
+  auto attrs = llvm::to_vector<8>(llvm::map_range(
+      values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
   return getArrayAttr(attrs);
 }
 

diff  --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 0c725e98fa3b..dd90511a0472 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -24,7 +24,6 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Passes.h"
 

diff  --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index ca738fde6103..4df97bca7724 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -19,7 +19,6 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/Passes.h"
 
@@ -33,8 +32,6 @@ using namespace mlir;
 
 using llvm::SetVector;
 
-using functional::map;
-
 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
 
 static llvm::cl::list<int> clTestVectorShapeRatio(
@@ -129,7 +126,6 @@ void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
 }
 
 static NestedPattern patternTestSlicingOps() {
-  using functional::map;
   using matcher::Op;
   // Match all operations with the kTestSlicingOpName name.
   auto filter = [](Operation &op) {

diff  --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 829445f7a148..aa7b50710cde 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -81,9 +81,9 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
 }
 
 // CHECK-LABEL: OpJ::verify()
-// CHECK:      llvm::is_splat(mlir::functional::map(
-// CHECK-SAME:   [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); },
-// CHECK-SAME:   llvm::ArrayRef<unsigned>({0, 2, 3})))
+// CHECK:      llvm::is_splat(llvm::map_range(
+// CHECK-SAME:   llvm::ArrayRef<unsigned>({0, 2, 3}),
+// CHECK-SAME:   [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
 // CHECK:   return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type");
 
 def OpK : NS_Op<"op_for_AnyTensorOf", []> {


        


More information about the Mlir-commits mailing list