[Mlir-commits] [mlir] 61ba9f9 - [mlir][Linalg] NFC - Extend the TilingInterface to allow better composition with out-of-tree dialects.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Dec 7 05:11:57 PST 2021


Author: Nicolas Vasilache
Date: 2021-12-07T13:06:27Z
New Revision: 61ba9f9110e230919982f2d82cdfcc4fe9840913

URL: https://github.com/llvm/llvm-project/commit/61ba9f9110e230919982f2d82cdfcc4fe9840913
DIFF: https://github.com/llvm/llvm-project/commit/61ba9f9110e230919982f2d82cdfcc4fe9840913.diff

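LOG: [mlir][Linalg] NFC - Extend the TilingInterface to allow better composition with out-of-tree dialects.

This renames `getLoopBounds` to `getIterationDomain`, changes `getTiledImplementation` to return a `SmallVector<Operation *>` instead of a single `Operation *`, and adds a `tileDestOperands` flag to `getTiledImplementation` that controls whether the `dest` operands are tiled as well.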

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D115233

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 50676b09508bb..c0c865842ce19 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -133,7 +133,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     [AttrSizedOperandSegments, NoSideEffect,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
      DeclareOpInterfaceMethods<TilingInterface,
-         ["getDestinationOperands", "getLoopIteratorTypes", "getLoopBounds",
+         ["getDestinationOperands", "getLoopIteratorTypes", "getIterationDomain",
           "getTiledImplementation"]>]> {
   let summary = "tensor pad operation";
   let description = [{

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index d82412b429212..6346899b39981 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -59,7 +59,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           step for the loops of the operation.
         }],
         /*retTy=*/"SmallVector<Range>",
-        /*methodName=*/"getLoopBounds",
+        /*methodName=*/"getIterationDomain",
         /*args=*/(ins "OpBuilder &":$b),
         /*methodBody=*/"",
         /*defaultImplementation=*/"return {};"
@@ -69,7 +69,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           Method to generate the tiled implementation of an operation.
 
           The iteration space of the operation is returned by
-          `getLoopBounds`. The caller provides the information of the
+          `getIterationDomain`. The caller provides the information of the
           tile within this iteration space whose implementation the
           caller needs.
           - `dest` are the Value into which the result of the tiled
@@ -79,20 +79,24 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           - `offsets` provides the offset of the tile within the
             iteration space
           - `sizes` provides the size of the tile.
+          - `tileDestOperands` specifies whether to also tile the `dest`
+            operands. Leaving the `dest` operands untiled can be useful for
+            composition with various looping container ops.
 
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"Operation *",
+        /*retType=*/"SmallVector<Operation *>",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
             "ValueRange ":$dest,
             "ArrayRef<OpFoldResult> ":$offsets,
-            "ArrayRef<OpFoldResult> ":$sizes),
+            "ArrayRef<OpFoldResult> ":$sizes,
+            "bool ":$tileDestOperands),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return nullptr;
+          return {};
         }]
       >
   ];

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 36828eabd59f7..a8abd275d61d4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1233,7 +1233,7 @@ SmallVector<StringRef> PadTensorOp::getLoopIteratorTypes() {
   return iteratorTypes;
 }
 
-SmallVector<Range> PadTensorOp::getLoopBounds(OpBuilder &b) {
+SmallVector<Range> PadTensorOp::getIterationDomain(OpBuilder &b) {
   ReifiedRankedShapedTypeDims reifiedShapes;
   (void)reifyResultShapes(b, reifiedShapes);
   Value zero = b.create<arith::ConstantIndexOp>(getLoc(), 0);
@@ -1246,13 +1246,13 @@ SmallVector<Range> PadTensorOp::getLoopBounds(OpBuilder &b) {
   return loopRanges;
 }
 
-Operation *PadTensorOp::getTiledImplementation(OpBuilder &b, ValueRange dest,
-                                               ArrayRef<OpFoldResult> offsets,
-                                               ArrayRef<OpFoldResult> sizes) {
+SmallVector<Operation *> PadTensorOp::getTiledImplementation(
+    OpBuilder &b, ValueRange dest, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, bool /*tileDestOperands*/) {
   // Only constant padding value supported.
   Value padValue = getConstantPaddingValue();
   if (!padValue)
-    return nullptr;
+    return {};
 
   // Helper variables and functions for various arithmetic operations. These are
   // used extensively for computing new offset/length and padding values.
@@ -1431,7 +1431,7 @@ Operation *PadTensorOp::getTiledImplementation(OpBuilder &b, ValueRange dest,
   // Rewrite subtensor(pad_tensor(x)) into a GenerateOp it is statically known
   // that the original data source x is not used.
   if (hasZeroLen) {
-    return createGenerateOp();
+    return {createGenerateOp()};
   }
 
   // If there are dynamic dimensions: Generate an scf.if check to avoid creating
@@ -1448,9 +1448,9 @@ Operation *PadTensorOp::getTiledImplementation(OpBuilder &b, ValueRange dest,
           b.create<scf::YieldOp>(loc,
                                  createPadTensorOfSubTensor()->getResult(0));
         });
-    return result;
+    return {result};
   }
-  return createPadTensorOfSubTensor();
+  return {createPadTensorOfSubTensor()};
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0a41153669083..c726063be04a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -351,7 +351,7 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
       options.tileSizeComputationFunction(builder, op);
   assert(static_cast<int64_t>(tileSizes.size()) == rank);
   // Compute lower and upper bounds of the loop nest.
-  SmallVector<Range> ranges = op.getLoopBounds(builder);
+  SmallVector<Range> ranges = op.getIterationDomain(builder);
   SmallVector<Value> lbs, dims, allDims, steps;
   for (int64_t i = 0; i < rank; ++i) {
     allDims.push_back(ranges[i].size);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e2d16cf285d39..4dbca8a308a1e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -905,9 +905,12 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   if (!sliceOp.hasUnitStride())
     return failure();
 
-  Operation *tiledPadOp = padOp.getTiledImplementation(
-      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
-      sliceOp.getMixedSizes());
+  Operation *tiledPadOp =
+      padOp
+          .getTiledImplementation(
+              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
+              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
+          .front();
   // All shapes are static and the data source is actually used. Rewrite into
   // pad_tensor(subtensor(x)).
   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());


        


More information about the Mlir-commits mailing list