[Mlir-commits] [mlir] 6b64957 - [mlir][Linalg] Refactor Linalg op initTensors support - NFC

Nicolas Vasilache llvmlistbot at llvm.org
Tue Sep 29 06:57:08 PDT 2020


Author: Nicolas Vasilache
Date: 2020-09-29T09:56:38-04:00
New Revision: 6b649570cbc44dd775d9657805cc60b2075d8011

URL: https://github.com/llvm/llvm-project/commit/6b649570cbc44dd775d9657805cc60b2075d8011
DIFF: https://github.com/llvm/llvm-project/commit/6b649570cbc44dd775d9657805cc60b2075d8011.diff

LOG: [mlir][Linalg] Refactor Linalg op initTensors support - NFC

Manually-defined named ops do not currently support `init_tensors` or return values and may never support them. Add extra interface methods to the StructuredOpInterface so that we can still write op-agnostic transformations based on StructuredOpInterface.

This is an NFC extension in preparation for tiling on tensors.

Differential Revision: https://reviews.llvm.org/D88481

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ed87689822e5..d12322933737 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -22,14 +22,19 @@ include "mlir/Interfaces/CopyOpInterface.td"
 // The Linalg `NInputs` trait provides the API for ops that are known
 // to have a specified number of inputs, all passed as operands.
 // See Linalg/LinalgTraits.h for implementation details and usage.
-class NInputs<int args_in> :
-  NativeOpTrait<"linalg::NInputs<" # !cast<string>(args_in) # ">::Impl"> {}
+class NInputs<int n> :
+  NativeOpTrait<"linalg::NInputs<" # !cast<string>(n) # ">::Impl"> {}
+
+// The Linalg `ZeroInitTensors` trait provides the API for ops that are known
+// to not have init tensor operands.
+// See Linalg/LinalgTraits.h for implementation details and usage.
+def ZeroInitTensors : NativeOpTrait<"linalg::ZeroInitTensors"> {}
 
 // The Linalg `NOutputs` trait provides the API for ops that are known
 // to have a specified number of outputs, all passed as operands.
 // See Linalg/LinalgTraits.h for implementation details and usage.
-class NOutputs<int args_out> :
-  NativeOpTrait<"linalg::NOutputs<" # !cast<string>(args_out) # ">::Impl"> {}
+class NOutputs<int n> :
+  NativeOpTrait<"linalg::NOutputs<" # !cast<string>(n) # ">::Impl"> {}
 
 def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
 def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
@@ -62,6 +67,7 @@ class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
 def CopyOp : LinalgStructured_Op<"copy", [
     CopyOpInterface,
     NInputs<1>,
+    ZeroInitTensors,
     NOutputs<1>
   ]> {
   let description = [{
@@ -159,7 +165,10 @@ def CopyOp : LinalgStructured_Op<"copy", [
   let hasCanonicalizer = 1;
 }
 
-def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
+def FillOp : LinalgStructured_Op<"fill", [
+    NInputs<0>,
+    ZeroInitTensors,
+    NOutputs<1>]> {
 
   let arguments = (ins AnyStridedMemRef:$output,
                    AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
@@ -254,7 +263,12 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
   }];
 }
 
-def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
+def ConvOp : PoolingBase_Op<"conv", [
+    NInputs<2>,
+    // Despite having reductions, this manually defined ConvOp may only take
+    // memref operands and can never have init tensors.
+    ZeroInitTensors,
+    NOutputs<1>]> {
 
   let description = [{
     Generic n-D convolution as described in the TF documentation:
@@ -371,7 +385,12 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
 }
 
 class SingleInputPoolingBase_Op<string mnemonic>
-  : PoolingBase_Op<mnemonic, [NInputs<2>, NOutputs<1>]> {
+  : PoolingBase_Op<mnemonic, [
+    NInputs<2>,
+    // Despite having reductions, these manually defined pooling ops may only
+    // take memref operands and can never have init tensors.
+    ZeroInitTensors,
+    NOutputs<1>]> {
   let description = [{
     A base class for single input pooling function.
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 17e16a15d39a..23d296c392ff 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -125,13 +125,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
                getNumIterators(getReductionIteratorTypeName(), iters) == 1;
       }]>,
     //===------------------------------------------------------------------===//
-    // Num input/output arguments handling.
+    // Num input/output/initTensors arguments handling.
     //===------------------------------------------------------------------===//
     // These special methods must be defined by each op that wants to implement
     // the LinalgStructuredInterface. For now, this is either:
-    // - inherited statically by using the NInputs<unsigned> or
-    //   NOutputs<unsigned> traits.
-    // - derived from args_in/args_out attributes (for linalg.generic and
+    // - Explicitly specified in the op definition.
+    // - Derived from variadic attributes (for "named" ops, linalg.generic and
     //   linalg.indexed_generic ops).
     InterfaceMethod<
       /*desc=*/[{
@@ -140,6 +139,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*retTy=*/"unsigned",
       /*methodName=*/"getNumInputs"
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of init tensors.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumInitTensors"
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the number of outputs.
@@ -371,6 +377,46 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the range over init tensors.
+      }],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getInitTensors",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = this->getOperation()->getOperands();
+        return {range.begin() + getNumInputsAndOutputBuffers(),
+                range.begin() + getNumInputsAndOutputs()};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return one single init tensor at position `$i`.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getInitTensor",
+      /*args=*/(ins "unsigned":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(i < $_op.getNumInitTensors() && "overflowing init tensor index");
+        return getInitTensors()[i];
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the range over inputs, output buffers and init tensors.
+      }],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getShapedOperands",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = this->getOperation()->getOperands();
+        return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the `i`-th shaped type, there are 3 cases:
@@ -445,7 +491,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::to_vector<4>($_op.indexing_maps().template getAsValueRange<AffineMapAttr>());
+        return llvm::to_vector<4>(
+          $_op.indexing_maps().template getAsValueRange<AffineMapAttr>());
       }]
     >,
     InterfaceMethod<
@@ -528,11 +575,11 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }],
       /*retTy=*/"Operation *",
       /*methodName=*/"create",
-      (ins "OpBuilder &":$builder, "Location":$loc,
+      (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
            "ValueRange":$operands,
            "ArrayRef<NamedAttribute>":$attributes), [{
-        return builder.create<ConcreteOp>(loc, TypeRange{}, operands,
-                                          attributes);
+        return builder.create<ConcreteOp>(
+          loc, resultTypes, operands, attributes);
       }]
     >,
     InterfaceMethod<
@@ -542,10 +589,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }],
       /*retTy=*/"Operation *",
       /*methodName=*/"clone",
-      (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{
+      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
+           "ValueRange":$operands),
+      [{
         BlockAndValueMapping map;
         unsigned numRegions = $_op.getOperation()->getNumRegions();
-        Operation *res = create(b, loc, operands, $_op.getAttrs());
+        Operation *res = create(b, loc, resultTypes, operands, $_op.getAttrs());
         assert(res->getNumRegions() == numRegions && "inconsistent # regions");
         for (unsigned ridx = 0; ridx < numRegions; ++ridx)
           $_op.getOperation()->getRegion(ridx).cloneInto(

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index 1df2b21bdade..5f1c756ca446 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -35,6 +35,17 @@ template <unsigned N> class NInputs {
   };
 };
 
+/// This class provides the API for ops that are known to not have init tensor
+/// operands. Use as a trait as follows:
+///
+///   class CopyOp : public Op<CopyOp, OpTrait::ZeroInitTensors> {
+///
+template <typename ConcreteType>
+class ZeroInitTensors : public TraitBase<ConcreteType, ZeroInitTensors> {
+public:
+  static unsigned getNumInitTensors() { return 0; }
+};
+
 /// This class provides the API for ops that are known to have a specified
 /// number of outputs, all passed as operands. Use as a trait as follows:
 ///
@@ -87,6 +98,9 @@ class NamedStructuredOpTrait
   unsigned getNumInputs() {
     return cast<ConcreteType>(this->getOperation()).inputs().size();
   }
+  unsigned getNumInitTensors() {
+    return cast<ConcreteType>(this->getOperation()).init_tensors().size();
+  }
   unsigned getNumOutputs() {
     ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
     return concreteOp.output_buffers().size() +

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 04d417480f3b..dfc977daa207 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -99,7 +99,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
   auto operands = getAssumedNonViewOperands(op);
   clonedViews.append(operands.begin(), operands.end());
 
-  Operation *clonedOp = op.clone(b, loc, clonedViews);
+  Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
   // When the producer is an IndexedGenercOp, we have to transform its block
   // IV arguments according to the tiling of the consumer, i.e. offset them by
   // the values computed in `loopRanges`.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 676caa145c3a..3db801bc2d57 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -405,7 +405,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
                                     tileSizes, allViewSizes);
         auto operands = getAssumedNonViewOperands(op);
         views.append(operands.begin(), operands.end());
-        res = op.clone(b, loc, views);
+        res = op.clone(b, loc, /*resultTypes*/ {}, views);
         return scf::ValueVector{};
       },
       options.distribution);


        


More information about the Mlir-commits mailing list