[Mlir-commits] [mlir] d4fbf83 - [mlir][EDSC] NFC - Move StructuredIndexed and IteratorType out of Linalg

Nicolas Vasilache llvmlistbot at llvm.org
Sat Feb 8 10:42:49 PST 2020


Author: Nicolas Vasilache
Date: 2020-02-08T13:42:28-05:00
New Revision: d4fbf8312b966b669bc52b33bf9cf30648883921

URL: https://github.com/llvm/llvm-project/commit/d4fbf8312b966b669bc52b33bf9cf30648883921
DIFF: https://github.com/llvm/llvm-project/commit/d4fbf8312b966b669bc52b33bf9cf30648883921.diff

LOG: [mlir][EDSC] NFC - Move StructuredIndexed and IteratorType out of Linalg

Summary:
This NFC revision moves those classes out of Linalg so that they can be
reused to build structured vector operations.

Reviewers: aartbik, ftynse

Subscribers: arphaman, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74279

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/include/mlir/EDSC/Builders.h
    mlir/lib/Dialect/Linalg/EDSC/Builders.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index 21345334641d..cff93f13cd35 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -87,57 +87,6 @@ template <typename LoopTy> class GenericLoopNestRangeBuilder {
   std::unique_ptr<BuilderType> builder;
 };
 
-enum class IterType { Parallel, Reduction };
-
-inline StringRef toString(IterType t) {
-  switch (t) {
-  case IterType::Parallel:
-    return getParallelIteratorTypeName();
-  case IterType::Reduction:
-    return getReductionIteratorTypeName();
-  }
-  llvm_unreachable("Unsupported IterType");
-}
-
-/// A StructuredIndexed represents an indexable quantity that is either:
-/// 1. a captured value, which is suitable for buffer and tensor operands, or;
-/// 2. a captured type, which is suitable for tensor return values.
-///
-/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
-/// It enable an idiomatic syntax for index expressions such as:
-///
-/// ```
-///      StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
-///        C(buffer_value_or_tensor_type);
-///      makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
-/// ```
-struct StructuredIndexed : public ValueHandle {
-  StructuredIndexed(Type type) : ValueHandle(type) {}
-  StructuredIndexed(Value value) : ValueHandle(value) {}
-  StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
-  StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
-    return StructuredIndexed(*this, indexings);
-  }
-
-  ArrayRef<AffineExpr> getExprs() { return exprs; }
-
-private:
-  StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
-      : ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
-    assert(t.isa<RankedTensorType>() && "RankedTensor expected");
-  }
-  StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
-      : ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
-    assert((v.getType().isa<MemRefType>() ||
-            v.getType().isa<RankedTensorType>()) &&
-           "MemRef or RankedTensor expected");
-  }
-  StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
-      : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
-
-  SmallVector<AffineExpr, 4> exprs;
-};
-
 inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
 
 /// Build a `linalg.generic` op with the specified `inputs`, `outputs` and
@@ -157,7 +106,7 @@ inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
 /// restriction output tensor results would need to be reordered, which would
 /// result in surprising behavior when combined with region definition.
 Operation *makeGenericLinalgOp(
-    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+    ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
     function_ref<void(ArrayRef<BlockArgument>)> regionBuilder =
         defaultRegionBuilder,

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index fbe97bce28c6..782a60c4e024 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -84,6 +84,19 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
   return res;
 }
 
+/// Typed representation for loop type strings.
+enum class IteratorType { Parallel, Reduction };
+
+inline StringRef toString(IteratorType t) {
+  switch (t) {
+  case IteratorType::Parallel:
+    return getParallelIteratorTypeName();
+  case IteratorType::Reduction:
+    return getReductionIteratorTypeName();
+  }
+  llvm_unreachable("Unsupported IteratorType");
+}
+
 } // end namespace mlir
 
 #endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H

diff  --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index a96399cb2125..ab40e333a2c5 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Transforms/FoldUtils.h"
 
@@ -493,6 +494,46 @@ class BlockHandle : public CapturableHandle {
   mlir::Block *block;
 };
 
+/// A StructuredIndexed represents an indexable quantity that is either:
+/// 1. a captured value, which is suitable for buffer and tensor operands, or
+/// 2. a captured type, which is suitable for tensor return values.
+///
+/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
+/// It enables an idiomatic syntax for index expressions such as:
+///
+/// ```
+///      StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
+///        C(buffer_value_or_tensor_type);
+///      makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
+/// ```
+struct StructuredIndexed : public ValueHandle {
+  StructuredIndexed(Type type) : ValueHandle(type) {}
+  StructuredIndexed(Value value) : ValueHandle(value) {}
+  StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
+  StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
+    return this->hasValue() ? StructuredIndexed(this->getValue(), indexings)
+                            : StructuredIndexed(this->getType(), indexings);
+  }
+
+  StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
+      : ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
+    assert(t.isa<RankedTensorType>() && "RankedTensor expected");
+  }
+  StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
+      : ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
+    assert((v.getType().isa<MemRefType>() ||
+            v.getType().isa<RankedTensorType>()) &&
+           "MemRef or RankedTensor expected");
+  }
+  StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
+      : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
+
+  ArrayRef<AffineExpr> getExprs() { return exprs; }
+
+private:
+  SmallVector<AffineExpr, 4> exprs;
+};
+
 template <typename Op, typename... Args>
 OperationHandle OperationHandle::create(Args... args) {
   return OperationHandle(ScopedContext::getBuilder()

diff  --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 35fd38aa49ec..7aaf6307a8e5 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -144,7 +145,7 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
 }
 
 Operation *mlir::edsc::makeGenericLinalgOp(
-    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+    ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
     function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
     ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
@@ -240,8 +241,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
 Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                              StructuredIndexed I,
                                              StructuredIndexed O) {
-  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
-                                           edsc::IterType::Parallel);
+  SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
+                                         IteratorType::Parallel);
   if (O.getType().isa<RankedTensorType>()) {
     auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
       assert(args.size() == 1 && "expected 1 block arguments");
@@ -270,8 +271,8 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
                                              StructuredIndexed I1,
                                              StructuredIndexed I2,
                                              StructuredIndexed O) {
-  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
-                                           edsc::IterType::Parallel);
+  SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
+                                         IteratorType::Parallel);
   if (O.getType().isa<RankedTensorType>()) {
     auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
       assert(args.size() == 2 && "expected 2 block arguments");
@@ -315,7 +316,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(vC);
   return makeGenericLinalgOp(
-    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+    {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n})},
     {C({m, n})},
     macRegionBuilder);
@@ -329,7 +330,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(tC);
   return makeGenericLinalgOp(
-    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+    {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n})},
     {C({m, n})},
     mulRegionBuilder);
@@ -343,7 +344,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(vC), D(tD);
   return makeGenericLinalgOp(
-    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+    {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
     {A({m, k}), B({k, n}), C({m, n})},
     {D({m, n})},
     macRegionBuilder);
@@ -360,8 +361,8 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
 
   // Some short names.
-  auto par = IterType::Parallel;
-  auto red = IterType::Reduction;
+  auto par = IteratorType::Parallel;
+  auto red = IteratorType::Reduction;
   auto s = strides;
   auto d = dilations;
 
@@ -393,8 +394,8 @@ Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
   assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
 
   // Some short names.
-  auto par = IterType::Parallel;
-  auto red = IterType::Reduction;
+  auto par = IteratorType::Parallel;
+  auto red = IteratorType::Reduction;
   auto s = strides;
   auto d = dilations;
 


        


More information about the Mlir-commits mailing list