[Mlir-commits] [mlir] 24199f5 - [mlir][linalg] Lower subtensor(pad_tensor) to pad_tensor(subtensor)

Matthias Springer llvmlistbot at llvm.org
Fri Jun 18 21:45:11 PDT 2021


Author: Matthias Springer
Date: 2021-06-19T13:44:47+09:00
New Revision: 24199f534f61d9ac7d2d9dcde7b9cac93c84d4f0

URL: https://github.com/llvm/llvm-project/commit/24199f534f61d9ac7d2d9dcde7b9cac93c84d4f0
DIFF: https://github.com/llvm/llvm-project/commit/24199f534f61d9ac7d2d9dcde7b9cac93c84d4f0.diff

LOG: [mlir][linalg] Lower subtensor(pad_tensor) to pad_tensor(subtensor)

Only high padding is supported at the moment. Low padding will be added in a separate commit.
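
As an illustration (mirroring the @static_mixed_data_high_pad test case
added below, with the pad_tensor regions abbreviated), the rewrite turns

    %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] { ... }
        : tensor<4x5xf32> to tensor<11x13xf32>
    %1 = subtensor %0[2, 4] [3, 4] [1, 1]
        : tensor<11x13xf32> to tensor<3x4xf32>

into

    %0 = subtensor %arg0[2, 4] [2, 1] [1, 1]
        : tensor<4x5xf32> to tensor<2x1xf32>
    %1 = linalg.pad_tensor %0 low[0, 0] high[1, 3] { ... }
        : tensor<2x1xf32> to tensor<3x4xf32>

so that only the data that is actually read is extracted from the source,
and the requested padding is re-applied around it.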

Differential Revision: https://reviews.llvm.org/D104357

Added: 
    mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9df8fbb2e4693..8841af104a360 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1076,6 +1076,15 @@ LogicalResult applyStagedPatterns(
     const FrozenRewritePatternSet &stage2Patterns,
     function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
 
+/// Rewrite subtensor(pad_tensor(x)) into pad_tensor(subtensor(x)).
+struct SubTensorOfPadTensorSwapPattern
+    : public OpRewritePattern<SubTensorOp> {
+  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace linalg
 } // namespace mlir
 

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index efd0c3b2079d1..4c2df05f52cb1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -700,3 +700,225 @@ LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
 
   return success();
 }
+
+/// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
+/// it must be an IntegerAttr.
+static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
+  if (auto val = ofr.dyn_cast<Value>())
+    return val;
+  auto intVal = getConstantIntValue(ofr);
+  assert(intVal && "expected Value or IntegerAttr");
+  return builder.create<ConstantIndexOp>(loc, *intVal);
+}
+
+/// Given a value, try to extract a constant index-type integer as an Attribute.
+/// If this fails, return the original value.
+static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
+  if (auto constInt = getConstantIntValue(val))
+    return builder.getIndexAttr(*constInt);
+  return val;
+}
+
+LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
+    SubTensorOp subTensorOp, PatternRewriter &rewriter) const {
+  auto padOp = subTensorOp.source().getDefiningOp<PadTensorOp>();
+  if (!padOp)
+    return failure();
+  // Only unit stride supported.
+  if (!subTensorOp.hasUnitStride())
+    return failure();
+  // Only constant padding value supported.
+  Value padValue = padOp.getConstantPaddingValue();
+  if (!padValue)
+    return failure();
+  // Only zero low padding supported at the moment.
+  if (!padOp.hasZeroLowPad())
+    return failure();
+
+  // Helper variables and functions for various arithmetic operations. These
+  // are used extensively for computing new offset/length and padding values.
+  Location loc = subTensorOp.getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  // Add two integers.
+  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
+  auto add = [&](Value v1, Value v2) {
+    return rewriter.createOrFold<AffineApplyOp>(loc, addMap,
+                                                ValueRange{v1, v2});
+  };
+  // Subtract two integers.
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](Value v1, Value v2) {
+    return rewriter.createOrFold<AffineApplyOp>(loc, subMap,
+                                                ValueRange{v1, v2});
+  };
+  // Take the minimum of two integers.
+  auto idMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+  auto min = [&](Value v1, Value v2) {
+    return rewriter.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
+  };
+  // Take the maximum of two integers.
+  auto max = [&](Value v1, Value v2) {
+    return rewriter.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
+  };
+  // Zero index-typed integer.
+  auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+
+  // Helper function for filling the static/dynamic low/high padding index
+  // vectors of PadTensorOp.
+  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
+                         SmallVector<int64_t> &staticIndices) {
+    if (auto constInt = getConstantIntValue(val)) {
+      staticIndices.push_back(*constInt);
+    } else {
+      staticIndices.push_back(ShapedType::kDynamicSize);
+      dynIndices.push_back(val);
+    }
+  };
+
+  // Compute new offsets, lengths, low padding, high padding.
+  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+  SmallVector<Value> newLows, newHighs;
+  SmallVector<int64_t> staticNewLows, staticNewHighs;
+  // Set to true if the original data source is not read at all.
+  bool hasZeroLen = false;
+  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
+  // is true if the original data source turns out to be unused at runtime.
+  Value dynHasZeroLenCond;
+
+  int64_t rank = padOp.getSourceType().getRank();
+  for (unsigned dim = 0; dim < rank; ++dim) {
+    auto offset = asValue(rewriter, loc, subTensorOp.getMixedOffsets()[dim]);
+    auto length = asValue(rewriter, loc, subTensorOp.getMixedSizes()[dim]);
+    auto srcSize = rewriter.createOrFold<memref::DimOp>(
+        loc, padOp.source(), dim);
+
+    // Existing low padding is zero, so new low padding is also zero.
+    Value newLow = zero;
+    appendIndex(newLow, newLows, staticNewLows);
+
+    // There is no low padding, so the offset remains unchanged, except when
+    // the SubTensorOp starts reading from a position within the high padding.
+    // In that case, set the offset to the end of the source tensor, so that
+    // the new SubTensorOp length will be zero. (Effectively, no data is read
+    // from the source.)
+    Value newOffset = min(offset, srcSize);
+    newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
+
+    // The new SubTensorOp starts reading at `newOffset` and reads until
+    // `offset + length`. This position may be outside of the source (i.e.,
+    // within the high padding). In that case, read only until the end of the
+    // source. In mathematical terms:
+    //
+    // endLoc = min(offset + length, srcSize)
+    //
+    // The new SubTensorOp length is `endLoc - newOffset`.
+    Value newLength = sub(min(add(offset, length), srcSize), newOffset);
+    newLengths.push_back(asOpFoldResult(rewriter, newLength));
+    if (auto newLengthInt = getConstantIntValue(newLength)) {
+      hasZeroLen |= *newLengthInt == 0;
+    } else {
+      Value check = rewriter.create<CmpIOp>(
+          loc, CmpIPredicate::eq, newLength, zero);
+      dynHasZeroLenCond = dynHasZeroLenCond
+          ? rewriter.create<OrOp>(loc, check, dynHasZeroLenCond) : check;
+    }
+
+    // The number of elements available to read from the source (starting from
+    // the new offset) is `maxRead = srcSize - newOffset`. The original
+    // SubTensorOp may have read a larger number of elements `length > maxRead`.
+    // In that case, the missing number of elements `length - maxRead` must be
+    // padded. (If `maxRead > length`, more than enough data is available to
+    // read and no high padding is needed.)
+    Value newHigh = max(zero, add(sub(newOffset, srcSize), length));
+    appendIndex(newHigh, newHighs, staticNewHighs);
+
+    // Only unit stride supported.
+    newStrides.push_back(rewriter.getIndexAttr(1));
+  }
+
+  // Insert cast to ensure that types match. (May be folded away.)
+  auto castResult = [&](Value val) -> Value {
+    auto castOp = rewriter.create<tensor::CastOp>(
+        loc, subTensorOp.getType(), val);
+    return castOp;
+  };
+
+  // In cases where the original data source is unused, emit a GenerateOp
+  // instead of a SubTensorOp. (The result shape of the SubTensorOp would have
+  // a dimension of size 0, the semantics of which are unclear.)
+  auto createGenerateOp = [&]() {
+    // The shape of the GenerateOp is the same as the existing SubTensorOp.
+    RankedTensorType type = subTensorOp.getType();
+    SmallVector<Value> dynDims;
+    for (unsigned i = 0; i < type.getRank(); ++i) {
+      if (type.isDynamicDim(i))
+        dynDims.push_back(
+            asValue(rewriter, loc, subTensorOp.getMixedSizes()[i]));
+    }
+
+    // Create GenerateOp.
+    auto generateOp = rewriter.create<tensor::GenerateOp>(loc, type, dynDims);
+
+    // Copy region to new op.
+    BlockAndValueMapping bvm;
+    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
+    // Rewrite linalg::YieldOp to tensor::YieldOp.
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      auto yieldOp = dyn_cast<linalg::YieldOp>(
+          generateOp.getRegion().front().getTerminator());
+      assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
+      assert(yieldOp.values().size() == 1);
+      rewriter.setInsertionPoint(yieldOp);
+      rewriter.replaceOpWithNewOp<tensor::YieldOp>(
+          yieldOp, yieldOp.values()[0]);
+    }
+
+    return castResult(generateOp);
+  };
+
+  // Emit a SubTensorOp and a PadTensorOp. Should not be used in cases where
+  // the result shape of the new SubTensorOp has a zero dimension.
+  auto createPadTensorOfSubTensor = [&]() {
+    // Create pad_tensor(subtensor(x)).
+    auto newSubTensorOp = rewriter.create<SubTensorOp>(
+        loc, padOp.source(), newOffsets, newLengths, newStrides);
+    auto newPadTensorOp = rewriter.create<PadTensorOp>(
+        loc, newSubTensorOp, staticNewLows, staticNewHighs, newLows, newHighs);
+
+    // Copy region to new PadTensorOp.
+    BlockAndValueMapping bvm;
+    padOp.region().cloneInto(&newPadTensorOp.getRegion(), bvm);
+
+    // Cast result and return.
+    return castResult(newPadTensorOp);
+  };
+
+  // Rewrite subtensor(pad_tensor(x)) into a GenerateOp if it is statically
+  // known that the original data source x is not used.
+  if (hasZeroLen) {
+    rewriter.replaceOp(subTensorOp, createGenerateOp());
+    return success();
+  }
+
+  // If there are dynamic dimensions, generate an scf.if check to avoid
+  // creating SubTensorOps with result dimensions of size 0 at runtime.
+  if (dynHasZeroLenCond) {
+    auto result = rewriter.create<scf::IfOp>(
+        loc, subTensorOp.getType(), dynHasZeroLenCond,
+        /*thenBuilder=*/[&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc, createGenerateOp());
+        },
+        /*elseBuilder=*/[&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc, createPadTensorOfSubTensor());
+        });
+    rewriter.replaceOp(subTensorOp, result.getResult(0));
+    return success();
+  }
+
+  // All shapes are static and the data source is actually used. Rewrite into
+  // pad_tensor(subtensor(x)).
+  rewriter.replaceOp(subTensorOp, createPadTensorOfSubTensor());
+  return success();
+}
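
To make the index arithmetic above concrete, here is the per-dimension
computation worked through for dimension 0 of the
@static_mixed_data_high_pad test case below (srcSize = 4, offset = 2,
length = 3):

    newOffset = min(offset, srcSize)                      = min(2, 4)     = 2
    newLength = min(offset + length, srcSize) - newOffset = min(5, 4) - 2 = 2
    newHigh   = max(0, newOffset - srcSize + length)      = max(0, 1)     = 1

Since low padding is zero, newLength + newHigh = length. When the
subtensor reads only from the high padding (e.g., dimension 0 of
@static_high_pad_only: offset = 4, length = 2), newLength folds to 0 and
the GenerateOp path is taken instead.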

diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
new file mode 100644
index 0000000000000..7d9c770946e63
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-swap-subtensor-padtensor -canonicalize  -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @static_data_only(
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<4x5xf32>
+//       CHECK:   %[[RESULT:.*]] = subtensor %[[ARG0]][1, 2] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
+//       CHECK:   return %[[RESULT]]
+func @static_data_only(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<2x1xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+    } : tensor<4x5xf32> to tensor<11x13xf32>
+  %1 = subtensor %0[1, 2] [2, 1] [1, 1] : tensor<11x13xf32> to tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @static_high_pad_only
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+//   CHECK-NOT:   linalg.pad_tensor
+//   CHECK-NOT:   subtensor
+//       CHECK:   %[[RESULT:.*]] = tensor.generate
+//       CHECK:     tensor.yield %[[PAD]]
+//       CHECK:   return %[[RESULT]] : tensor<2x4xf32>
+func @static_high_pad_only(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<2x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+    } : tensor<4x5xf32> to tensor<11x13xf32>
+  %1 = subtensor %0[4, 5] [2, 4] [1, 1] : tensor<11x13xf32> to tensor<2x4xf32>
+  return %1 : tensor<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @static_mixed_data_high_pad
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+//   CHECK-NOT:   linalg.pad_tensor
+//       CHECK:   %[[SUBTENSOR:.*]] = subtensor %[[ARG0]][2, 4] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
+//       CHECK:   %[[RESULT:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[1, 3]
+//       CHECK:     linalg.yield %[[PAD]]
+//       CHECK:   return %[[RESULT]] : tensor<3x4xf32>
+func @static_mixed_data_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<3x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+    } : tensor<4x5xf32> to tensor<11x13xf32>
+  %1 = subtensor %0[2, 4] [3, 4] [1, 1] : tensor<11x13xf32> to tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_high_pad
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x5xf32>
+//   CHECK-NOT:   linalg.pad_tensor
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   memref.dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[RESULT:.*]] = scf.if %{{.*}} -> (tensor<3x4xf32>) {
+//       CHECK:     %[[GEN:.*]] = tensor.generate
+//       CHECK:     scf.yield %[[GEN]]
+//       CHECK:   } else {
+//       CHECK:     %[[SUBTENSOR:.*]] = subtensor %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
+//       CHECK:     %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
+//       CHECK:     %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
+//       CHECK:     scf.yield %[[CAST]]
+//       CHECK:   }
+//       CHECK:   return %[[RESULT]]
+func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[%h1, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+    } : tensor<?x5xf32> to tensor<?x13xf32>
+  %1 = subtensor %0[2, 4] [3, 4] [1, 1] : tensor<?x13xf32> to tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
+
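
For reference, the output for @static_high_pad_only boils down to a single
tensor.generate after canonicalization (a sketch; SSA names are
illustrative):

    func @static_high_pad_only(%arg0: tensor<4x5xf32>, %pad: f32)
        -> tensor<2x4xf32> {
      %0 = tensor.generate {
      ^bb0(%arg1: index, %arg2: index):
        tensor.yield %pad : f32
      } : tensor<2x4xf32>
      return %0 : tensor<2x4xf32>
    }

The requested subtensor lies entirely within the high padding, so the
original source is never read.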

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 0037db27610ef..402a26475a503 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -97,6 +97,11 @@ struct TestLinalgTransforms
       *this, "test-transform-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
       llvm::cl::init(false)};
+  Option<bool> testSwapSubTensorPadTensor{
+      *this, "test-swap-subtensor-padtensor",
+      llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
+                     "pad_tensor(subtensor)"),
+      llvm::cl::init(false)};
   ListOption<int64_t> tileSizesForPadding{
       *this, "tile-sizes-for-padding",
       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
@@ -524,6 +529,12 @@ static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applySubTensorOfPadTensorSwapPattern(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  patterns.add<SubTensorOfPadTensorSwapPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
   RewritePatternSet foldPattern(funcOp.getContext());
   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
@@ -602,6 +613,8 @@ void TestLinalgTransforms::runOnFunction() {
     return applyLinalgToVectorPatterns(getFunction());
   if (testTransformPadTensor)
     return applyPadTensorToGenericPatterns(getFunction());
+  if (testSwapSubTensorPadTensor)
+    return applySubTensorOfPadTensorSwapPattern(getFunction());
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
   if (testTileAndPadPattern)

More information about the Mlir-commits mailing list