[Mlir-commits] [mlir] 30e6033 - [mlir][Linalg] Add TensorsToBuffers support for Constant ops.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Oct 8 06:16:45 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-08T13:15:45Z
New Revision: 30e6033b455bfa4b888eedb2cfe808a61845ed5f

URL: https://github.com/llvm/llvm-project/commit/30e6033b455bfa4b888eedb2cfe808a61845ed5f
DIFF: https://github.com/llvm/llvm-project/commit/30e6033b455bfa4b888eedb2cfe808a61845ed5f.diff

LOG: [mlir][Linalg] Add TensorsToBuffers support for Constant ops.

This revision also inserts an end-to-end test that lowers tensors to buffers all the way to executable code on CPU.

Differential revision: https://reviews.llvm.org/D88998

Added: 
    mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/lib/Transforms/BufferPlacement.cpp
    mlir/test/Dialect/Linalg/tensors-to-buffers.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index dcf4b5ec06cb..4e25772d5793 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -20,36 +20,39 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
 	    "Only folds the one-trip loops from Linalg ops on tensors "
 	    "(for testing purposes only)">
   ];
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgFusion : FunctionPass<"linalg-fusion"> {
   let summary = "Fuse operations in the linalg dialect";
   let constructor = "mlir::createLinalgFusionPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
-  let dependentDialects = ["AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
   let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
-  let dependentDialects = ["AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
 def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
   let summary = "Lower the operations from the linalg dialect into loops";
   let constructor = "mlir::createConvertLinalgToLoopsPass()";
-  let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"];
 }
 
 def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
   let summary = "Convert the Linalg operations which work on tensor-type "
                 "operands or results to use buffers instead";
   let constructor = "mlir::createConvertLinalgOnTensorsToBuffersPass()";
+  let dependentDialects = ["linalg::LinalgDialect", "vector::VectorDialect"];
 }
 
 def LinalgLowerToParallelLoops
@@ -57,7 +60,7 @@ def LinalgLowerToParallelLoops
   let summary = "Lower the operations from the linalg dialect into parallel "
                 "loops";
   let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
-  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
+  let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
 def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@@ -69,13 +72,14 @@ def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
     Option<"useAlloca", "test-use-alloca", "bool",
            /*default=*/"false", "Test generation of alloca'ed buffers.">
   ];
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgTiling : FunctionPass<"linalg-tile"> {
   let summary = "Tile operations in the linalg dialect";
   let constructor = "mlir::createLinalgTilingPass()";
   let dependentDialects = [
-    "AffineDialect", "scf::SCFDialect"
+    "AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"
   ];
   let options = [
     ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
@@ -93,7 +97,7 @@ def LinalgTilingToParallelLoops
                "Test generation of dynamic promoted buffers",
                "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
   ];
-  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
+  let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
 #endif // MLIR_DIALECT_LINALG_PASSES

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 2354cc6abd89..0a8e5b679277 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
 class MLIRContext;

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f8fc4dead86f..11a8cf74ddd5 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
 
 def Vector_Dialect : Dialect {
   let name = "vector";
@@ -1673,7 +1674,7 @@ def Vector_BitCastOp :
 }
 
 def Vector_TypeCastOp :
-  Vector_Op<"type_cast", [NoSideEffect]>,
+  Vector_Op<"type_cast", [NoSideEffect, ViewLikeOpInterface]>,
     Arguments<(ins StaticShapeMemRefOf<[AnyType]>:$memref)>,
     Results<(outs AnyMemRef:$result)> {
   let summary = "type_cast op converts a scalar memref to a vector memref";
@@ -1711,6 +1712,8 @@ def Vector_TypeCastOp :
     MemRefType getResultMemRefType() {
       return getResult().getType().cast<MemRefType>();
     }
+    // ViewLikeOpInterface: the result is a view aliasing the `memref` operand.
+    Value getViewSource() { return memref(); }
   }];
 
   let assemblyFormat = [{

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
new file mode 100644
index 000000000000..a2bd18c7a3b1
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -convert-linalg-on-tensors-to-buffers -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @foo() -> tensor<4xf32> {
+  %0 = constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+func @main() {
+  %0 = call @foo() : () -> tensor<4xf32>
+
+  // Instead of relying on tensor_store which introduces aliasing, we rely on
+  // the conversion of print_memref_f32(tensor<*xf32>) to
+  // print_memref_f32(memref<*xf32>).
+  // Note that this is skipping a step and we would need at least some function
+  // attribute to declare that this conversion is valid (e.g. when we statically
+  // know that things will play nicely at the C ABI boundary).
+  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  //      CHECK: Unranked Memref base@ = {{0x[0-9a-f]*}}
+  // CHECK-SAME: rank = 1 offset = 0 sizes = [4] strides = [1] data =
+  // CHECK-NEXT: [1, 2, 3, 4]
+
+  return
+}
+
+// This gets converted to a function operating on memref<*xf32>.
+// Note that this is skipping a step and we would need at least some function
+// attribute to declare that this conversion is valid (e.g. when we statically
+// know that things will play nicely at the C ABI boundary).
+func @print_memref_f32(%ptr : tensor<*xf32>)

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ac94a421903d..a95e1006e381 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1098,23 +1098,35 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
   bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
-/// Returns true if the memory underlying `memRefType` has a contiguous layout.
-/// Strides are written to `strides`.
-static bool isContiguous(MemRefType memRefType,
-                         SmallVectorImpl<int64_t> &strides) {
+/// Returns the strides of `memRefType` if its layout is provably contiguous
+/// (identity layout, or unit innermost stride with no gaps), None otherwise.
+static llvm::Optional<SmallVector<int64_t, 4>>
+computeContiguousStrides(MemRefType memRefType) {
   int64_t offset;
-  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
-  bool isContiguous = strides.empty() || strides.back() == 1;
-  if (isContiguous) {
-    auto sizes = memRefType.getShape();
-    for (int index = 0, e = strides.size() - 2; index < e; ++index) {
-      if (strides[index] != strides[index + 1] * sizes[index + 1]) {
-        isContiguous = false;
-        break;
-      }
-    }
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(memRefType, strides, offset)))
+    return None;
+  if (!strides.empty() && strides.back() != 1)
+    return None;
+  // If no layout or identity layout, this is contiguous by definition.
+  if (memRefType.getAffineMaps().empty() ||
+      memRefType.getAffineMaps().front().isIdentity())
+    return strides;
+
+  // Otherwise, check every adjacent pair of dimensions from the shape:
+  // strides[i] == strides[i + 1] * sizes[i + 1]. This only works in static
+  // cases because MemRefType is underspecified to represent contiguous
+  // dynamic shapes other than with an empty/identity layout.
+  auto sizes = memRefType.getShape();
+  for (int64_t index = 0, e = strides.size(); index + 1 < e; ++index) {
+    if (ShapedType::isDynamic(sizes[index + 1]) ||
+        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
+        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
+      return None;
+    if (strides[index] != strides[index + 1] * sizes[index + 1])
+      return None;
   }
-  return succeeded(successStrides) && isContiguous;
+  return strides;
 }
 
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
@@ -1150,9 +1162,17 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return failure();
 
-    // Only contiguous source tensors supported atm.
-    SmallVector<int64_t, 4> strides;
-    if (!isContiguous(sourceMemRefType, strides))
+    // Only contiguous source and target buffers are supported atm.
+    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
+    if (!sourceStrides)
+      return failure();
+    auto targetStrides = computeContiguousStrides(targetMemRefType);
+    if (!targetStrides)
+      return failure();
+    // Target strides become LLVM constants below, so they must be static.
+    if (llvm::any_of(*targetStrides, [](int64_t stride) {
+          return ShapedType::isDynamicStrideOrOffset(stride);
+        }))
       return failure();
 
     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
@@ -1181,8 +1201,8 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
       desc.setSize(rewriter, loc, index, size);
-      auto strideAttr =
-          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
+      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
+                                                (*targetStrides)[index]);
       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
       desc.setStride(rewriter, loc, index, stride);
     }
@@ -1223,8 +1243,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
                                        op->getContext()))
       return failure();
     // Only contiguous source tensors supported atm.
-    SmallVector<int64_t, 4> strides;
-    if (!isContiguous(xferOp.getMemRefType(), strides))
+    auto strides = computeContiguousStrides(xferOp.getMemRefType());
+    if (!strides)
       return failure();
 
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
@@ -1380,9 +1400,11 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern {
 
 private:
   enum class PrintConversion {
+    // clang-format off
     None,
     ZeroExt64,
     SignExt64
+    // clang-format on
   };
 
   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
index 0415aeb8a1fd..dc23a3bd8599 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
@@ -18,10 +18,18 @@ namespace mlir {
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);
 
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
+
 namespace scf {
 class SCFDialect;
 } // end namespace scf
 
+namespace vector {
+class VectorDialect;
+} // end namespace vector
+
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Linalg/Passes.h.inc"
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 3282358f5f41..a9ec621efb9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
@@ -203,6 +204,64 @@ class GenericOpConverter
   }
 };
 
+// Rewrite a dense ranked tensor `constant` to a vector constant stored into an
+// alloc'ed `memref<vector<...>>`, followed by a `vector.type_cast`.
+class TensorConstantOpConverter
+    : public BufferAssignmentOpConversionPattern<ConstantOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      ConstantOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto attr = op.getValue().dyn_cast<DenseElementsAttr>();
+    if (!attr || !op.getType().isa<RankedTensorType>())
+      return failure();
+
+    Location loc = op.getLoc();
+    MemRefType memrefType =
+        converter->convertType(op.getType()).cast<MemRefType>();
+    VectorType vectorType =
+        VectorType::get(memrefType.getShape(), memrefType.getElementType());
+
+    // vector constant takes attributes that are compatible with tensor
+    // constant.
+    Value cstVec =
+        rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
+
+    // Alloc a memref<vector<...>>, store the constant and typecast the vector
+    // away.
+    MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
+    Value alloc =
+        rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
+    rewriter.create<StoreOp>(loc, cstVec, alloc);
+    rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
+
+    return success();
+  }
+};
+
+// Rewrite a `tensor_cast` to an unranked tensor as a `memref_cast` to an
+// unranked memref in memory space 0. Ranked `tensor_cast`s are not handled.
+class TensorCastOpConverter
+    : public BufferAssignmentOpConversionPattern<TensorCastOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      TensorCastOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (op.getType().hasRank())
+      return failure();
+    Type t = UnrankedMemRefType::get(op.getType().getElementType(),
+                                     /*memorySpace=*/0);
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
+    return success();
+  }
+};
+
 /// Converts Linalg operations that work on tensor-type operands or results to
 /// work on buffers.
 struct ConvertLinalgOnTensorsToBuffers
@@ -213,7 +272,7 @@ struct ConvertLinalgOnTensorsToBuffers
     BufferAssignmentTypeConverter converter;
 
     // Mark all Standard operations legal.
-    target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<ModuleTerminatorOp>();
 
@@ -225,12 +284,33 @@ struct ConvertLinalgOnTensorsToBuffers
         Optional<ConversionTarget::DynamicLegalityCallbackFn>(
             isLegalOperation));
 
-    // Mark Standard Return operations illegal as long as one operand is tensor.
-    target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
-      return converter.isLegal(returnOp.getOperandTypes());
-    });
+    // These ops are legal only when no operand or result type is a tensor.
+    auto isLegal = [&](Operation *op) {
+      if (llvm::any_of(op->getOperandTypes(),
+                       [&](Type t) { return !converter.isLegal(t); }))
+        return false;
+      if (llvm::any_of(op->getResultTypes(),
+                       [&](Type t) { return !converter.isLegal(t); }))
+        return false;
+      return true;
+    };
+    target.addDynamicallyLegalOp<
+        // clang-format off
+        CallOp,
+        ConstantOp,
+        ConstantIntOp,
+        ConstantIndexOp,
+        ConstantFloatOp,
+        ReturnOp,
+        TensorCastOp
+        // clang-format on
+        >(isLegal);
 
     // Mark the function operation illegal as long as an argument is tensor.
+    // TODO: if the FuncOp only has a declaration (e.g. to an externally defined
+    // symbol like an external library call), only convert it if some special
+    // attribute is set. This will allow more control of interop across ABI
+    // boundaries.
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
       return converter.isSignatureLegal(funcOp.getType()) &&
              llvm::none_of(funcOp.getType().getResults(),
@@ -261,5 +341,11 @@ mlir::createConvertLinalgOnTensorsToBuffersPass() {
 void mlir::linalg::populateConvertLinalgOnTensorsToBuffersPatterns(
     MLIRContext *context, BufferAssignmentTypeConverter *converter,
     OwningRewritePatternList *patterns) {
-  patterns->insert<GenericOpConverter>(context, converter);
+  patterns->insert<
+      // clang-format off
+      GenericOpConverter,
+      TensorCastOpConverter,
+      TensorConstantOpConverter
+      // clang-format on
+      >(context, converter);
 }

diff  --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 9f2c254f91e5..4f97fb836246 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -913,70 +913,75 @@ LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
 // BufferAssignmentCallOpConverter
 //===----------------------------------------------------------------------===//
 
-/// Performs the actual rewriting step.
-LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
-    CallOp callOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
+namespace {
+// This class represents a mapping from a result to a list of values and some
+// results that have not yet been constructed. Instead, the indices of these
+// results in the operation that will be constructed are known. They will be
+// replaced with the actual values when they are available. The order of
+// adding to this mapping is important.
+class CallOpResultMapping {
+public:
+  CallOpResultMapping() : order(0) {}
 
-  // This class represents a mapping from a result to a list of values and some
-  // results that have not yet constructed. Instead, the indices of these
-  // results in the operation that will be constructed are known. They will be
-  // replaced with the actual values when they are available. The order of
-  // adding to this mapping is important.
-  class ResultMapping {
-  public:
-    ResultMapping() { order = 0; };
-
-    /// Add an available value to the mapping.
-    void addMapping(Value value) {
-      toValuesMapping.push_back({order++, value});
-    }
+  /// Add an available value to the mapping.
+  void addMapping(Value value) { toValuesMapping.push_back({order++, value}); }
 
-    /// Add the index of unavailble result value to the mapping.
-    void addMapping(unsigned index) {
-      toIndicesMapping.push_back({order++, index});
-    }
+  /// Add the index of an unavailable result value to the mapping.
+  void addMapping(unsigned index) {
+    toIndicesMapping.push_back({order++, index});
+  }
 
-    /// This method returns the mapping values list. The unknown result values
-    /// that only their indicies are available are replaced with their values.
-    void getMappingValues(ValueRange valuesToReplaceIndices,
-                          SmallVectorImpl<Value> &values) {
-      // Append available values to the list.
-      SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
-                                                     toValuesMapping.end());
-      // Replace the indices with the actual values.
-      llvm::for_each(
-          toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
-            assert(entry.second < valuesToReplaceIndices.size() &&
-                   "The value index is out of range.");
-            res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
-          });
-      // Sort the values based on their adding orders.
-      llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
-                         const std::pair<unsigned, Value> &v2) {
-        return v1.first < v2.first;
-      });
-      // Fill the values.
-      llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
-        values.push_back(entry.second);
-      });
-    }
+  /// Appends the mapping values to `values` in insertion order. Results known
+  /// only by index are resolved through `valuesToReplaceIndices`.
+  void getMappingValues(ValueRange valuesToReplaceIndices,
+                        SmallVectorImpl<Value> &values) {
+    // Append available values to the list.
+    SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
+                                                   toValuesMapping.end());
+    // Replace the indices with the actual values.
+    llvm::for_each(
+        toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
+          assert(entry.second < valuesToReplaceIndices.size() &&
+                 "The value index is out of range.");
+          res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
+        });
+    // Sort the values by their insertion order.
+    llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
+                       const std::pair<unsigned, Value> &v2) {
+      return v1.first < v2.first;
+    });
+    // Fill the values.
+    llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
+      values.push_back(entry.second);
+    });
+  }
 
-  private:
-    /// Keeping the inserting order of mapping values.
-    int order;
+private:
+  /// Counter recording the insertion order of mapping entries.
+  unsigned order;
 
-    /// Containing the mapping values with their inserting orders.
-    SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
+  /// Available mapping values, tagged with their insertion order.
+  SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
 
-    /// Containing the indices of result values with their inserting orders.
-    SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
-  };
+  /// Indices of not-yet-constructed results, tagged with insertion order.
+  SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
+};
+} // namespace
+
+/// Performs the actual rewriting step.
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
+    CallOp callOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
 
   Location loc = callOp.getLoc();
   OpBuilder builder(callOp);
   SmallVector<Value, 2> newOperands;
 
+  // TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
+  // to an externally defined symbol like an external library call), only
+  // convert if some special attribute is set.
+  // This will allow more control of interop across ABI boundaries.
+
   // Create the operands list of the new `CallOp`. It unpacks the decomposable
   // values if a decompose callback function has been provided by the user.
   for (auto operand : operands) {
@@ -989,7 +994,7 @@ LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
   // Create the new result types for the new `CallOp` and a mapping from the old
   // result to new value(s).
   SmallVector<Type, 2> newResultTypes;
-  SmallVector<ResultMapping, 4> mappings;
+  SmallVector<CallOpResultMapping, 4> mappings;
   mappings.resize(callOp.getNumResults());
   for (auto result : llvm::enumerate(callOp.getResults())) {
     SmallVector<Type, 2> originTypes;

diff  --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
index 4339b33a2379..7d714092cb7c 100644
--- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
+++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -123,3 +123,53 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
 // CHECK: linalg.copy(%[[OUT_BUF_2]], %[[OUT_2]]) : [[TYPE]], [[TYPE]]
 // CHECK: dealloc %[[OUT_BUF_2]] : [[TYPE]]
 // CHECK: return
+
+// -----
+
+func @foo() -> tensor<4xf32> {
+// CHECK-LABEL: func @foo(
+//  CHECK-SAME:   %[[A:[0-9a-z]*]]: memref<4xf32>) {
+
+  %0 = constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32>
+//  CHECK-NEXT:   %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<4xf32>
+//  CHECK-NEXT:   %[[ALLOC:.*]] = alloc() : memref<vector<4xf32>>
+//  CHECK-NEXT:   store %[[CST]], %[[ALLOC]][] : memref<vector<4xf32>>
+//  CHECK-NEXT:   %[[RES:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4xf32>> to memref<4xf32>
+
+  return %0 : tensor<4xf32>
+//  CHECK-NEXT:   linalg.copy(%[[RES]], %[[A]]) : memref<4xf32>, memref<4xf32>
+//  CHECK-NEXT:   dealloc %[[ALLOC]] : memref<vector<4xf32>>
+//  CHECK-NEXT:   return
+}
+
+func @bar() {
+// CHECK-LABEL: func @bar() {
+
+  %0 = call @foo() : () -> tensor<4xf32>
+//  CHECK-NEXT:   %[[ALLOC:.*]] = alloc() : memref<4xf32>
+//  CHECK-NEXT:   call @foo(%[[ALLOC]]) : (memref<4xf32>) -> ()
+
+  // Instead of relying on tensor_store which introduces aliasing, we rely on
+  // the conversion of print_memref_f32(tensor<*xf32>) to
+  // print_memref_f32(memref<*xf32>).
+  // Note that this is skipping a step and we would need at least some function
+  // attribute to declare that this conversion is valid (e.g. when we statically
+  // know that things will play nicely at the C ABI boundary).
+  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+//  CHECK-NEXT:   %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
+//  CHECK-SAME:     memref<4xf32> to memref<*xf32>
+
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+//  CHECK-NEXT:   call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
+
+  return
+//  CHECK-NEXT:   dealloc %[[ALLOC]] : memref<4xf32>
+//  CHECK-NEXT:   return
+}
+
+// This gets converted to a function operating on memref<*xf32>.
+// Note that this is skipping a step and we would need at least some function
+// attribute to declare that this conversion is valid (e.g. when we statically
+// know that things will play nicely at the C ABI boundary).
+func @print_memref_f32(%ptr : tensor<*xf32>)
+// CHECK-LABEL: func @print_memref_f32(memref<*xf32>)


        


More information about the Mlir-commits mailing list