[Mlir-commits] [mlir] 9ee12f4 - [mlir][tensor][bufferize] Bufferize tensor.pad

Matthias Springer llvmlistbot at llvm.org
Mon Aug 22 08:05:35 PDT 2022


Author: Matthias Springer
Date: 2022-08-22T17:00:33+02:00
New Revision: 9ee12f47785929acaa5f71d0ae51e08f0f3acbab

URL: https://github.com/llvm/llvm-project/commit/9ee12f47785929acaa5f71d0ae51e08f0f3acbab
DIFF: https://github.com/llvm/llvm-project/commit/9ee12f47785929acaa5f71d0ae51e08f0f3acbab.diff

LOG: [mlir][tensor][bufferize] Bufferize tensor.pad

tensor.pad is lowered to tensor.generate + tensor.insert_slice during bufferization. For best performance with constant padding values, users should vectorize the IR before bufferizing it.

This change also relaxes the restriction that no new ops that bufferize to a memory write may be added during bufferization. Since bufferization was split into two steps some time ago (tensor copy insertion + bufferization), it is now reasonable to allow this.

Differential Revision: https://reviews.llvm.org/D132355

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir
    mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index e6d38953ba638..a233367b38bf1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -351,17 +351,6 @@ class BufferizationRewriter : public IRRewriter {
     if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
       return;
 
-#ifndef NDEBUG
-    // Read-only tensor ops may be created during bufferization. Ops that are
-    // writing should not be created because such ops were never analyzed.
-    // Bufferizing such ops could introduce a RaW conflict.
-    for (OpOperand &operand : op->getOpOperands())
-      if (operand.get().getType().isa<TensorType>())
-        assert(!analysisState.bufferizesToMemoryWrite(operand) &&
-               "creating tensor ops that bufferize to a memory write is not "
-               "allowed during bufferization");
-#endif // NDEBUG
-
     // Add op to worklist.
     worklist.push_back(op);
   }

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 38044da2e4dbc..881237c499ed8 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 
@@ -739,6 +740,92 @@ struct InsertSliceOpInterface
   }
 };
 
+/// Bufferization of tensor.pad. Replace with tensor.generate + insert_slice.
+/// For best performance, vectorize before bufferization; this matters most
+/// when padding with a constant value.
+struct PadOpInterface
+    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
+                                                    tensor::PadOp> {
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto padOp = cast<tensor::PadOp>(op);
+    Location loc = padOp.getLoc();
+    RankedTensorType resultType = padOp.getResultType();
+    RankedTensorType srcType = padOp.getSourceType();
+
+    auto toValue = [&](OpFoldResult ofr) {
+      if (ofr.is<Value>())
+        return ofr.get<Value>();
+      return rewriter
+          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
+          .getResult();
+    };
+
+    // Compute dynamic result dimensions.
+    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
+    SmallVector<Value> dynamicSizes;
+    for (int64_t i = 0; i < resultType.getRank(); ++i) {
+      if (!resultType.isDynamicDim(i))
+        continue;
+      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
+      Value lowPad = toValue(mixedLowPad[i]);
+      Value highPad = toValue(mixedHighPad[i]);
+      Value s1 = rewriter.create<arith::AddIOp>(loc, lowPad, highPad);
+      Value s2 = rewriter.create<arith::AddIOp>(loc, s1, srcDim);
+      dynamicSizes.push_back(s2);
+    }
+
+    // Create tensor::GenerateOp.
+    auto generateOp =
+        rewriter.create<tensor::GenerateOp>(loc, resultType, dynamicSizes);
+    // Move over "escape" attribute if present.
+    if (padOp->hasAttr(BufferizationDialect::kEscapeAttrName))
+      generateOp->setAttr(
+          BufferizationDialect::kEscapeAttrName,
+          padOp->getAttr(BufferizationDialect::kEscapeAttrName));
+    // TODO: Memory space
+    rewriter.inlineRegionBefore(padOp.getRegion(), generateOp.getBody(),
+                                generateOp.getBody().begin());
+
+    // Create tensor::InsertSliceOp.
+    SmallVector<OpFoldResult> sliceSizes, sliceStrides;
+    for (int64_t i = 0; i < resultType.getRank(); ++i) {
+      sliceStrides.push_back(rewriter.getIndexAttr(1));
+      if (srcType.isDynamicDim(i)) {
+        Value size = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
+        sliceSizes.push_back(size);
+      } else {
+        sliceSizes.push_back(rewriter.getIndexAttr(srcType.getDimSize(i)));
+      }
+    }
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+        padOp, padOp.getSource(), generateOp.getResult(),
+        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
+
+    return success();
+  }
+};
+
 /// Bufferization of tensor.rank. Replace with memref.rank.
 struct RankOpInterface
     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
@@ -982,6 +1069,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
     InsertOp::attachInterface<InsertOpInterface>(*ctx);
     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
+    PadOp::attachInterface<PadOpInterface>(*ctx);
     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
         *ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 937588e045bba..8d53585cb1d8c 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -544,3 +544,36 @@ func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
   // CHECK: return %[[r]]
   return %reshaped : tensor<2x2x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.pad(
+//  CHECK-SAME:   %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
+func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
+                      %h2: index) -> tensor<?x?xindex> {
+  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xindex>
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
+  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
+  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
+  // CHECK-DAG: %[[pad0:.*]] = arith.addi %[[c5]], %[[h1]]
+  // CHECK-DAG: %[[size0:.*]] = arith.addi %[[pad0]], %[[dim0]]
+  // CHECK-DAG: %[[pad1:.*]] = arith.addi %[[l2]], %[[h2]]
+  // CHECK-DAG: %[[size1:.*]] = arith.addi %[[pad1]], %[[dim1]]
+  // CHECK:     %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref<?x?xindex>
+  // CHECK:     scf.parallel ({{.*}}) = (%[[c0]], %[[c0]]) to (%[[size0]], %[[size1]]) step (%[[c1]], %[[c1]]) {
+  // CHECK:       memref.store
+  // CHECK:     }
+  // CHECK:     %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
+  // CHECK:     memref.copy %[[m1]], %[[subview]]
+  %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %m = arith.muli %arg0, %arg1 : index
+    tensor.yield %m : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+
+  // CHECK:     %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK:     return %[[r]] : tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}

diff  --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 589bfcd4aac57..220d18d2011c7 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -236,3 +236,21 @@ func.func @dealloc_generate_buffer(%arg: tensor<*xf32>, %sz: index, %idx: index)
   %r = tensor.extract %0[%idx] : tensor<?xindex>
   return %r : index
 }
+
+// -----
+
+// CHECK-LABEL: func @dealloc_pad_buffer
+func.func @dealloc_pad_buffer(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
+                              %h2: index, %idx: index) -> index {
+  // CHECK: memref.alloc
+  // CHECK: scf.parallel
+  // CHECK: memref.load
+  // CHECK: memref.dealloc
+  %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %m = arith.muli %arg0, %arg1 : index
+    tensor.yield %m : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  %r = tensor.extract %0[%idx, %idx] : tensor<?x?xindex>
+  return %r : index
+}


        


More information about the Mlir-commits mailing list