[Mlir-commits] [mlir] 0ac3d97 - [mlir][Linalg] Fix pad hoisting.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Feb 10 08:54:24 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-10T16:49:38Z
New Revision: 0ac3d97bf4943f8d0d0fedf3545cdb099dadcb1f

URL: https://github.com/llvm/llvm-project/commit/0ac3d97bf4943f8d0d0fedf3545cdb099dadcb1f
DIFF: https://github.com/llvm/llvm-project/commit/0ac3d97bf4943f8d0d0fedf3545cdb099dadcb1f.diff

LOG: [mlir][Linalg] Fix pad hoisting.

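This revision fixes the indexing logic into the packed tensor that results from hoisting padding. Previously, the index was incorrectly set to the loop induction variable, when in fact it must be the iteration count (i.e. `(iv - lb).ceilDiv(step)`). The two coincide only when the lower bound is 0 and the step is 1; otherwise the induction variable can exceed the packed dimension, which is sized by the loop trip count `(ub - lb).ceilDiv(step)`.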

Differential Revision: https://reviews.llvm.org/D96417

Added: 
    

Modified: 
    mlir/include/mlir/IR/AffineExpr.h
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/test/Dialect/Linalg/hoist-padding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 3e4e1c014b58..e71448716930 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -79,10 +79,14 @@ class AffineExpr {
 
   bool operator!() const { return expr == nullptr; }
 
-  template <typename U> bool isa() const;
-  template <typename U> U dyn_cast() const;
-  template <typename U> U dyn_cast_or_null() const;
-  template <typename U> U cast() const;
+  template <typename U>
+  bool isa() const;
+  template <typename U>
+  U dyn_cast() const;
+  template <typename U>
+  U dyn_cast_or_null() const;
+  template <typename U>
+  U cast() const;
 
   MLIRContext *getContext() const;
 
@@ -251,7 +255,8 @@ AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
 
 raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
 
-template <typename U> bool AffineExpr::isa() const {
+template <typename U>
+bool AffineExpr::isa() const {
   if (std::is_same<U, AffineBinaryOpExpr>::value)
     return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
   if (std::is_same<U, AffineDimExpr>::value)
@@ -261,15 +266,18 @@ template <typename U> bool AffineExpr::isa() const {
   if (std::is_same<U, AffineConstantExpr>::value)
     return getKind() == AffineExprKind::Constant;
 }
-template <typename U> U AffineExpr::dyn_cast() const {
+template <typename U>
+U AffineExpr::dyn_cast() const {
   if (isa<U>())
     return U(expr);
   return U(nullptr);
 }
-template <typename U> U AffineExpr::dyn_cast_or_null() const {
+template <typename U>
+U AffineExpr::dyn_cast_or_null() const {
   return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
 }
-template <typename U> U AffineExpr::cast() const {
+template <typename U>
+U AffineExpr::cast() const {
   assert(isa<U>());
   return U(expr);
 }
@@ -282,28 +290,46 @@ AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
                               unsigned numSymbols);
 
 namespace detail {
-template <int N> void bindDims(MLIRContext *ctx) {}
+template <int N>
+void bindDims(MLIRContext *ctx) {}
 
 template <int N, typename AffineExprTy, typename... AffineExprTy2>
-void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &... exprs) {
+void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
   e = getAffineDimExpr(N, ctx);
   bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
 }
+
+template <int N>
+void bindSymbols(MLIRContext *ctx) {}
+
+template <int N, typename AffineExprTy, typename... AffineExprTy2>
+void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
+  e = getAffineSymbolExpr(N, ctx);
+  bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
+}
 } // namespace detail
 
 /// Bind a list of AffineExpr references to DimExpr at positions:
 ///   [0 .. sizeof...(exprs)]
 template <typename... AffineExprTy>
-void bindDims(MLIRContext *ctx, AffineExprTy &... exprs) {
+void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
   detail::bindDims<0>(ctx, exprs...);
 }
 
+/// Bind a list of AffineExpr references to SymbolExpr at positions:
+///   [0 .. sizeof...(exprs)]
+template <typename... AffineExprTy>
+void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
+  detail::bindSymbols<0>(ctx, exprs...);
+}
+
 } // namespace mlir
 
 namespace llvm {
 
 // AffineExpr hash just like pointers
-template <> struct DenseMapInfo<mlir::AffineExpr> {
+template <>
+struct DenseMapInfo<mlir::AffineExpr> {
   static mlir::AffineExpr getEmptyKey() {
     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
     return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 699be7817e0e..f3d98f634788 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -417,16 +417,28 @@ hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
   return success();
 }
 
-static Value buildLoopTripCount(OpBuilder &b, Operation *op) {
-  MLIRContext *ctx = op->getContext();
-  AffineExpr lb, ub, step = getAffineSymbolExpr(0, ctx);
+/// Return the number of iterations in the loop (ub - lb).ceilDiv(step).
+static Value buildLoopTripCount(OpBuilder &b, scf::ForOp forOp) {
+  MLIRContext *ctx = forOp->getContext();
+  AffineExpr lb, ub, step;
   bindDims(ctx, lb, ub);
-  scf::ForOp forOp = cast<scf::ForOp>(op);
+  bindSymbols(ctx, step);
   return b.create<AffineApplyOp>(
-      op->getLoc(), AffineMap::get(2, 1, {(ub - lb).ceilDiv(step)}, ctx),
+      forOp->getLoc(), AffineMap::get(2, 1, {(ub - lb).ceilDiv(step)}, ctx),
       ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()});
 }
 
+/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
+static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp forOp) {
+  MLIRContext *ctx = forOp->getContext();
+  AffineExpr iv, lb, step;
+  bindDims(ctx, iv, lb);
+  bindSymbols(ctx, step);
+  return b.create<AffineApplyOp>(
+      forOp->getLoc(), AffineMap::get(2, 1, {(iv - lb).ceilDiv(step)}, ctx),
+      ValueRange{forOp.getInductionVar(), forOp.lowerBound(), forOp.step()});
+}
+
 LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                   unsigned nLoops) {
   llvm::SetVector<Operation *> backwardSlice, packingLoops;
@@ -455,8 +467,10 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   llvm::append_range(packedShape, paddedTensorType.getShape());
   auto packedTensorType =
       RankedTensorType::get(packedShape, paddedTensorType.getElementType());
-  auto dynamicSizes = llvm::to_vector<4>(llvm::map_range(
-      packingLoops, [&](Operation *op) { return buildLoopTripCount(b, op); }));
+  auto dynamicSizes =
+      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
+        return buildLoopTripCount(b, cast<scf::ForOp>(op));
+      }));
   Value packedTensor = b.create<linalg::InitTensorOp>(
       loc, dynamicSizes, packedTensorType.getShape(),
       packedTensorType.getElementType());
@@ -469,8 +483,9 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   //   2. Create a SubTensorInsert at the top of the stack.
   //   3. Iteratively pop and yield the result of the SubTensorInsertOp across
   //     the cloned loops.
-  SmallVector<Value> clonedLoopIvs;
+  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
   clonedLoopIvs.reserve(nLoops);
+  leadingPackedTensorIndexings.reserve(nLoops);
   BlockAndValueMapping bvm;
   // Stack step 1. iteratively clone loops and push `packedTensor`.
   // Insert `padTensorOp` into the backwardSlice so we clone it too.
@@ -492,13 +507,16 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
     assert(clonedForOp->getNumRegions() == 1);
     clonedLoopIvs.push_back(clonedForOp.getInductionVar());
     b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
+    leadingPackedTensorIndexings.push_back(
+        buildLoopIterationCount(b, clonedForOp));
     bvm.map(forOp.getInductionVar(), clonedLoopIvs.back());
     packedTensor = clonedForOp.getRegionIterArgs().front();
   }
 
   // Stack step 2. create SubTensorInsertOp at the top of the stack.
   // offsets = [clonedLoopIvs, 0 .. 0].
-  SmallVector<OpFoldResult> offsets(clonedLoopIvs.begin(), clonedLoopIvs.end());
+  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
+                                    leadingPackedTensorIndexings.end());
   offsets.append(paddedRank, b.getIndexAttr(0));
   // sizes = [1 .. 1, paddedShape].
   SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
@@ -527,12 +545,12 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   // Now the packed tensor is ready, replace the original padding op by a
   // 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
   b.setInsertionPoint(padTensorOp);
-  SmallVector<Value> originalLoopIvs =
-      llvm::to_vector<4>(llvm::map_range(packingLoops, [](Operation *loop) {
-        return cast<scf::ForOp>(loop).getInductionVar();
+  SmallVector<Value> loopIterationCounts =
+      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
+        return buildLoopIterationCount(b, cast<scf::ForOp>(loop));
       }));
   // offsets = [originalLoopIvs, 0 .. 0].
-  offsets.assign(originalLoopIvs.begin(), originalLoopIvs.end());
+  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
   offsets.append(paddedRank, b.getIndexAttr(0));
   // sizes = [1 .. 1, paddedShape] (definedabove).
   // strides = [1 .. 1] (defined above)

diff  --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
index 8685df44db3f..77860928f41e 100644
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir
@@ -6,7 +6,15 @@
 #map3 = affine_map<(d0, d1) -> (2, d0 - d1)>
 #map4 = affine_map<(d0, d1) -> (3, d0 - d1)>
 
+// CHECK-DAG: #[[$DIV3:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 3)>
+// CHECK-DAG: #[[$DIV4:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 4)>
+// CHECK-DAG: #[[$DIVS3:[0-9a-z]+]] = affine_map<()[s0] -> (s0 ceildiv 3)>
+// CHECK-DAG: #[[$DIVS4:[0-9a-z]+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+
 // CHECK-LABEL: func @matmul_tensors
+//  CHECK-SAME:   %[[TA:[0-9a-z]+]]: tensor
+//  CHECK-SAME:   %[[TB:[0-9a-z]+]]: tensor
+//  CHECK-SAME:   %[[TC:[0-9a-z]+]]: tensor
 func @matmul_tensors(
   %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
   -> tensor<?x?xf32>
@@ -15,39 +23,60 @@ func @matmul_tensors(
   %c3 = constant 3 : index
   %c4 = constant 4 : index
   %cst = constant 0.000000e+00 : f32
+
+  //  CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  //  CHECK-DAG: %[[C1:.*]] = constant 1 : index
   %c0 = constant 0 : index
   %c1 = constant 1 : index
+
+  //  CHECK-DAG: %[[dM:.*]] = dim %[[TA]], %[[C0]] : tensor<?x?xf32>
+  //  CHECK-DAG: %[[dK:.*]] = dim %[[TA]], %[[C1]] : tensor<?x?xf32>
+  //  CHECK-DAG: %[[dN:.*]] = dim %[[TB]], %[[C1]] : tensor<?x?xf32>
   %0 = dim %arg0, %c0 : tensor<?x?xf32>
   %1 = dim %arg0, %c1 : tensor<?x?xf32>
   %2 = dim %arg1, %c1 : tensor<?x?xf32>
 
-  //      CHECK: scf.for
-  //      CHECK:   linalg.init_tensor [%{{.*}}, 2, 4] : tensor<?x2x4xf32>
+  //      CHECK: scf.for %[[I:[0-9a-z]+]] =
+  // First padded tensor is MxKx2x4 under loop M so Kx2x4
+  //      CHECK:   %[[SZpad0_K:[0-9]+]] = affine.apply #[[$DIVS4]]()[%[[dK]]]
+  //      CHECK:   linalg.init_tensor [%[[SZpad0_K]], 2, 4] : tensor<?x2x4xf32>
   // 1-D loop
-  //      CHECK:   %[[A:.*]] = scf.for
-  //  CHECK-NOT:     scf.for
+  //      CHECK:   %[[A:.*]] = scf.for %[[J1:[0-9a-z]+]] =
+  // Iteration count along J1
+  //      CHECK:     %[[IDXpad0_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[J1]])
   //      CHECK:     subtensor %{{.*}} [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   //      CHECK:     linalg.pad_tensor %{{.*}}
   //      CHECK:       : tensor<?x?xf32> to tensor<2x4xf32>
-  //      CHECK:     subtensor_insert %{{.*}} into %{{.*}}[%{{.*}}, 0, 0]
+  //      CHECK:     subtensor_insert %{{.*}} into %{{.*}}[%[[IDXpad0_K]], 0, 0]
   // CHECK-SAME:       [1, 2, 4] [1, 1, 1] : tensor<2x4xf32> into tensor<?x2x4xf32>
+  // Second padded tensor is KxNx2x4
+  //      CHECK:   %[[SZpad1_K:[0-9]+]] = affine.apply #[[$DIVS4]]()[%[[dK]]]
+  //      CHECK:   %[[SZpad1_N:[0-9]+]] = affine.apply #[[$DIVS3]]()[%[[dN]]]
+  //      CHECK:   linalg.init_tensor [%[[SZpad1_K]], %[[SZpad1_N]], 4, 3] : tensor<?x?x4x3xf32>
   // 2-D loop
-  //      CHECK:   linalg.init_tensor [%{{.*}}, %{{.*}}, 4, 3] : tensor<?x?x4x3xf32>
-  //      CHECK:   %[[B:.*]] = scf.for
-  //      CHECK:     scf.for
-  //  CHECK-NOT:       scf.for
+  //      CHECK:   %[[B:.*]] = scf.for %[[K2:[0-9a-z]+]] =
+  // Iteration count along K2
+  //      CHECK:     %[[IDXpad1_K:[0-9]+]] = affine.apply #[[$DIV3]](%[[K2]])
+  //      CHECK:     scf.for %[[J2:[0-9a-z]+]] =
+  // Iteration count along J2
+  //      CHECK:       %[[IDXpad1_N:[0-9]+]] = affine.apply #[[$DIV4]](%[[J2]])
   //      CHECK:       subtensor %{{.*}} [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   //      CHECK:       linalg.pad_tensor %{{.*}}
   //      CHECK:         : tensor<?x?xf32> to tensor<4x3xf32>
-  //      CHECK:       subtensor_insert %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
+  //      CHECK:       subtensor_insert %{{.*}} into %{{.*}}[%[[IDXpad1_K]], %[[IDXpad1_N]], 0, 0]
   // CHECK-SAME:         [1, 1, 4, 3] [1, 1, 1, 1] : tensor<4x3xf32> into tensor<?x?x4x3xf32>
   // 2-D loop
   //      CHECK:   scf.for %[[J:[0-9a-zA-Z]+]]
   //      CHECK:     scf.for %[[K:[0-9a-zA-Z]+]]
-  //  CHECK-NOT:       scf.for
-  //      CHECK:       %[[stA:.*]] = subtensor %[[A]][%[[K]], 0, 0] [1, 2, 4] [1, 1, 1] :
+  // Iteration count along K
+  //      CHECK:       %[[IDXpad0_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[K]])
+  //      CHECK:       %[[stA:.*]] = subtensor %[[A]][%[[IDXpad0_K]], 0, 0] [1, 2, 4] [1, 1, 1] :
   // CHECK-SAME:         tensor<?x2x4xf32> to tensor<2x4xf32>
-  //      CHECK:       %[[stB:.*]] = subtensor %[[B]][%[[K]], %[[J]], 0, 0] [1, 1, 4, 3] [1, 1, 1, 1] :
+  // Iteration count along K
+  //      CHECK:       %[[IDXpad1_K:[0-9]+]] = affine.apply #[[$DIV4]](%[[K]])
+  // Iteration count along J
+  //      CHECK:       %[[IDXpad1_N:[0-9]+]] = affine.apply #[[$DIV3]](%[[J]])
+  //      CHECK:       %[[stB:.*]] = subtensor %[[B]][%[[IDXpad1_K]], %[[IDXpad1_N]], 0, 0] [1, 1, 4, 3] [1, 1, 1, 1] :
   // CHECK-SAME:         tensor<?x?x4x3xf32> to tensor<4x3xf32>
   //      CHECK:       %[[stC:.*]] = linalg.pad_tensor %{{.*}}
   //      CHECK:         : tensor<?x?xf32> to tensor<2x3xf32>


        


More information about the Mlir-commits mailing list