[Mlir-commits] [mlir] 71bbb78 - [mlir][linalg][bufferize] Support tensor.generate

Matthias Springer llvmlistbot at llvm.org
Tue Jan 25 05:19:37 PST 2022


Author: Matthias Springer
Date: 2022-01-25T22:19:22+09:00
New Revision: 71bbb78b8fdc72732e3c21ee6d37f3c3868a7fdc

URL: https://github.com/llvm/llvm-project/commit/71bbb78b8fdc72732e3c21ee6d37f3c3868a7fdc
DIFF: https://github.com/llvm/llvm-project/commit/71bbb78b8fdc72732e3c21ee6d37f3c3868a7fdc.diff

LOG: [mlir][linalg][bufferize] Support tensor.generate

This is mostly a copy of the existing tensor.generate bufferization pattern. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite pattern can be deleted.

Differential Revision: https://reviews.llvm.org/D117770

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ea9d885736f90..aaa89b4ae2242 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -228,6 +229,65 @@ struct ExtractOpInterface
   }
 };
 
+/// Bufferization of tensor.generate: fill a new buffer with an scf.parallel.
+struct GenerateOpInterface
+    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
+                                                    tensor::GenerateOp> {
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto generateOp = cast<tensor::GenerateOp>(op);
+
+    // Allocate the result buffer (dynamic dims sized by dynamicExtents()).
+    Location loc = op->getLoc();
+    MemRefType memrefType =
+        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
+    FailureOr<Value> maybeResult =
+        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
+                    /*deallocMemref=*/state.getOptions().createDeallocs,
+                    state.getOptions());
+    if (failed(maybeResult))
+      return failure();
+    Value result = *maybeResult;
+
+    // Bounds per dim: [0, size), step 1; dynamic sizes come from the extents.
+    int64_t rank = memrefType.getRank();
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    SmallVector<Value, 4> lowerBounds(rank, zero);
+    SmallVector<Value, 4> steps(rank, one);
+    SmallVector<Value, 4> upperBounds;
+    int nextDynamicIndex = 0;
+    for (int i = 0; i < rank; i++) {
+      Value upperBound = memrefType.isDynamicDim(i)
+                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
+                             : rewriter.create<arith::ConstantIndexOp>(
+                                   loc, memrefType.getDimSize(i));
+      upperBounds.push_back(upperBound);
+    }
+
+    // Fill the buffer with a parallel loop over every index. mergeBlockBefore
+    // moves the generate body into the scf.parallel body, replacing the body's
+    // index block arguments with the loop induction variables.
+    auto parallel =
+        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
+    Block *parallelBody = parallel.getBody();
+    rewriter.mergeBlockBefore(generateOp.getBody(),
+                              parallelBody->getTerminator(),
+                              parallelBody->getArguments());
+    // The inlined tensor.yield now sits just before the scf.yield that the
+    // scf.parallel builder created; replace it with a memref.store of the
+    // yielded value into the buffer at the current loop indices.
+    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
+    rewriter.setInsertionPointAfter(elementYield);
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        elementYield, elementYield->getOperands()[0], result,
+        parallelBody->getArguments());
+
+    replaceOpWithBufferizedValues(rewriter, op, result);
+    return success();
+  }
+};
+
 /// Bufferization of tensor.insert. Replace with memref.store.
 struct InsertOpInterface
     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
@@ -502,6 +562,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addOpInterface<DimOp, DimOpInterface>();
   registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
   registry.addOpInterface<ExtractOp, ExtractOpInterface>();
+  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
   registry.addOpInterface<InsertOp, InsertOpInterface>();
   registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
   registry.addOpInterface<RankOp, RankOpInterface>();

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 28ee8bea2e9ec..eacb2bb9314dc 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1359,3 +1359,23 @@ func @tensor_rank(%arg0: tensor<*xf32>) -> index {
   // CHECK: return %[[r]] : index
   return %0 : index
 }
+
+// -----
+// tensor.generate with a static and a dynamic dim: alloc + scf.parallel fill.
+// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
+//  CHECK-SAME:     %[[arg0:.*]]: index
+func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
+  // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
+  %result = tensor.generate %arg0 {
+  ^bb0(%i: index, %j: index):
+    %sum = arith.addi %i, %j : index
+    // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
+    // CHECK: scf.yield
+    tensor.yield %sum : index
+  } : tensor<16x?xindex>
+  // CHECK: }
+  return %result : tensor<16x?xindex>
+}


        


More information about the Mlir-commits mailing list