[Mlir-commits] [mlir] [mlir][bufferization] Fix OneShotBufferize when `defaultMemorySpaceFn` is used (PR #91524)
llvmlistbot at llvm.org
Wed May 8 12:31:56 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-amx
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Christopher Bate (christopherbate)
Changes:
As described in issue llvm/llvm-project#91518, a previous PR (llvm/llvm-project#78484) introduced the `defaultMemorySpaceFn` into the bufferization options, allowing one to instruct OneShotBufferize to use a specified function to derive the memory space attribute from the encoding attribute attached to tensor types.
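To illustrate the effect, here is a minimal sketch based on the `from_elements` case in the new test added by this patch, run with the new `use-encoding-for-memory-space` pass option (the bufferized output in the trailing comment is abbreviated):

```mlir
// RUN: mlir-opt -one-shot-bufferize="use-encoding-for-memory-space"
func.func @from_elements(%fill: f32) -> tensor<3xf32, 1> {
  // The integer encoding `1` on the tensor type is interpreted as the
  // memory space of the buffer that backs it.
  %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
  return %t : tensor<3xf32, 1>
}
// After bufferization, the allocation carries the same memory space:
//   %alloc = memref.alloc() ... : memref<3xf32, 1>
```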
However, introducing this feature exposed several unhandled edge cases, examples of which are added by this change in the new test under
`test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir`.
Fixing the inconsistencies introduced by `defaultMemorySpaceFn` is fairly simple. This change:
- updates the `bufferization.to_memref` and `bufferization.to_tensor` operations
to spell out both the operand and result types explicitly, whereas previously they
relied on type inference to deduce the tensor types. Since type inference
cannot recover the correct tensor encoding/memory space, the operand and result
types must now be given explicitly (see the sketch after this list).
- makes minor updates to other bufferization functions to account for the
changes in how the above ops are built
- updates the bufferization of `tensor.from_elements` to handle the memory space
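
For reference, a minimal sketch of the new assembly format (adapted from the updated `finalizing-bufferize.mlir` tests; the function name and the memory space `1` are illustrative):

```mlir
// Previously the tensor type was inferred from the memref type:
//   %0 = bufferization.to_tensor %m : memref<?xf32>
//   %1 = bufferization.to_memref %0 : memref<?xf32>
// Now both types are spelled out, which lets an encoding (here the memory
// space `1`) survive the round trip:
func.func @round_trip(%m: memref<?xf32, 1>) -> memref<?xf32, 1> {
  %0 = bufferization.to_tensor %m : memref<?xf32, 1> -> tensor<?xf32, 1>
  %1 = bufferization.to_memref %0 : tensor<?xf32, 1> -> memref<?xf32, 1>
  return %1 : memref<?xf32, 1>
}
```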
---
Patch is 226.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91524.diff
68 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h (+6)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+12-6)
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+4)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-1)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+10)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+12-1)
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+12-10)
- (modified) mlir/test/Dialect/Arith/bufferize.mlir (+3-3)
- (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir (+1-1)
- (modified) mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir (+15-16)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir (+2-2)
- (added) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir (+133)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir (+3-3)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+3-3)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir (+1-1)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir (+2-2)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir (+1-1)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+1-1)
- (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+16-16)
- (modified) mlir/test/Dialect/Bufferization/ops.mlir (+3-3)
- (modified) mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir (+2-2)
- (modified) mlir/test/Dialect/Func/func-bufferize.mlir (+1-1)
- (modified) mlir/test/Dialect/Linalg/bufferize.mlir (+5-5)
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+1-1)
- (modified) mlir/test/Dialect/MemRef/normalize-memrefs.mlir (+2-2)
- (modified) mlir/test/Dialect/SCF/bufferize.mlir (+6-6)
- (modified) mlir/test/Dialect/Shape/bufferize.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_sddmm_lib.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/constant_index_map.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/dense.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sorted_coo.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+14-14)
- (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+39-39)
- (modified) mlir/test/Dialect/SparseTensor/sparse_3d.mlir (+41-41)
- (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+8-8)
- (modified) mlir/test/Dialect/SparseTensor/sparse_batch.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+11-11)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir (+17-17)
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels.mlir (+9-9)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower.mlir (+4-4)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir (+4-4)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+4-4)
- (modified) mlir/test/Dialect/SparseTensor/sparse_nd.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+6-6)
- (modified) mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_scalars.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+5-5)
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/unused-tensor.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+14-14)
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+17-17)
- (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+1-1)
- (modified) mlir/test/Dialect/Vector/bufferize.mlir (+3-3)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir (+2-2)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-full.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 6f19dca2e8222..d6ccbdd7acc1f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,10 +12,16 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir::bufferization::detail {
+bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
+} // namespace mlir::bufferization::detail
//===----------------------------------------------------------------------===//
// Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..7be44d566d481 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -388,9 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
BufferizableOpInterface,
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
- TypesMatchWith<"result type matches tensor equivalent of 'memref'",
- "memref", "result",
- "memref::getTensorTypeFromMemRefType($_self)">
+ AllElementTypesMatch<["memref", "result"]>
]> {
let summary = "create a tensor from a `memref`";
let description = [{
@@ -477,9 +475,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
let assemblyFormat = [{
$memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
- `:` type($memref)
+ `:` type($memref) `->` type($result)
}];
+ let builders = [
+ OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+ auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+ build($_builder, $_state, rtt, memref, restrict, writeable);
+ }]>
+ ];
+
let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -496,7 +501,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
Pure,
TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
"memref", "tensor",
- "memref::getTensorTypeFromMemRefType($_self)">
+ "memref::getTensorTypeFromMemRefType($_self)",
+ "bufferization::detail::tensorTypesMatchUpToEncoding">
]> {
let summary = "cast a tensor to memref";
let description = [{
@@ -551,7 +557,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
}];
let assemblyFormat = [{
- $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
}];
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 75ce85c9128c9..656edbfb3deaa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -520,6 +520,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
/*default=*/"false",
"The memory space of an memref types must always be inferred. If "
"unset, a default memory space of 0 is used otherwise.">,
+ Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+ /*default=*/"false",
+ "Use the Tensor encoding attribute for the memory space. Exclusive to"
+ " the 'must-infer-memory-space option'">,
Option<"testAnalysisOnly", "test-analysis-only", "bool",
/*default=*/"false",
"Test only: Only run inplaceability analysis and annotate IR">,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f243ea0..550ac7e83b9e0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
// loose all of its users and eventually DCE away.
rewriter.setInsertionPointAfter(op);
replacement = rewriter.create<bufferization::ToTensorOp>(
- replacement.getLoc(), replacement);
+ replacement.getLoc(), opResult.getType(), replacement);
}
replacements.push_back(replacement);
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..bfb742e5e0176 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
// Helper functions
//===----------------------------------------------------------------------===//
+bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
+ auto lhsType = cast<ShapedType>(lhs);
+ auto rhsType = cast<ShapedType>(rhs);
+ if (lhsType.getElementType() != rhsType.getElementType())
+ return false;
+ if (lhsType.hasRank() && rhsType.hasRank())
+ return lhsType.getShape() == rhsType.getShape();
+ return true;
+}
+
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
OpBuilder &b, Value value, MemRefType destType,
const BufferizationOptions &options) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7ba347a1f15e4..b43041d629dd3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
+ // Unranked to ranked casts must be explicit.
+ if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
+ return nullptr;
+
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
// MemRef to MemRef cast.
assert(inputType != type && "expected different types");
- // Unranked to ranked and ranked to unranked casts must be explicit.
+ // Ranked to unranked casts must be explicit.
auto rankedDestType = dyn_cast<MemRefType>(type);
if (!rankedDestType)
return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
[](TensorType t) -> std::optional<Attribute> {
return std::nullopt;
};
+ } else if (useEncodingForMemorySpace) {
+ opt.defaultMemorySpaceFn =
+ [](TensorType t) -> std::optional<Attribute> {
+ if (auto rtt = dyn_cast<RankedTensorType>(t))
+ return rtt.getEncoding();
+ return std::nullopt;
+ };
}
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..a46f500b76c3f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
- mixedSizes, mixedStrides);
+ loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+ mixedOffsets, mixedSizes, mixedStrides);
replaceOpWithBufferizedValues(rewriter, op, subView);
return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
- extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
- mixedOffsets, mixedSizes, mixedStrides));
+ extractSliceOp.getType().getShape(),
+ llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+ mixedStrides));
}
};
@@ -478,9 +479,8 @@ struct FromElementsOpInterface
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
- // TODO: Implement memory space for this op.
- if (options.defaultMemorySpaceFn(tensorType) != Attribute())
- return op->emitError("memory space not implemented yet");
+ std::optional<Attribute> memorySpace =
+ options.defaultMemorySpaceFn(tensorType);
// Allocate a buffer for the result.
Location loc = op->getLoc();
@@ -491,10 +491,12 @@ struct FromElementsOpInterface
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
- auto memrefType =
- MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ FailureOr<BaseMemRefType> memrefType =
+ bufferization::getBufferType(*tensorAlloc, options);
+ if (failed(memrefType))
+ return failure();
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
- op->getLoc(), memrefType, *tensorAlloc);
+ op->getLoc(), *memrefType, *tensorAlloc);
// Case: tensor<0xelem_type>.
if (fromElementsOp.getElements().empty()) {
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index 944954e9e4edd..31b4577cdd62f 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -8,7 +8,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
%index_scalar = arith.index_cast %scalar : i32 to index
return %index_tensor, %index_scalar : tensor<index>, index
}
-// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32> to memref<index>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -87,8 +87,8 @@ func.func @non_tensor() {
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
// CHECK: %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
// CHECK: return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
index 5293977fe733f..55e086ff0110f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
@@ -9,7 +9,7 @@
// CHECK-NEXT: %[[clone:.*]] = bufferization.clone %[[m]]
// CHECK-NEXT: return %[[clone]]
func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
- %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+ %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
return %0 : memref<?x?x?xf16>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d92..500bdb4f9afc5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -4,8 +4,8 @@
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
// CHECK: return %[[ARG]] : memref<f32>
func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
- %0 = bufferization.to_tensor %arg0 : memref<f32>
- %1 = bufferization.to_memref %0 : memref<f32>
+ %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
+ %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
return %1 : memref<f32>
}
@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
// expected-error @+1 {{failed to legalize operation 'test.source'}}
%0 = "test.source"() : () -> tensor<f32>
- %1 = bufferization.to_memref %0 : memref<f32>
+ %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
return %1 : memref<f32>
}
// -----
func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
- %0 = bufferization.to_tensor %arg0 : memref<f32>
+ %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
// expected-error @+1 {{failed to legalize operation 'test.sink'}}
"test.sink"(%0) : (tensor<f32>) -> ()
return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
// TODO: to_memref with layout maps not supported yet. This should fold to a
// memref.cast.
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
- %0 = bufferization.to_tensor %m : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
// expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
- %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
// expected-note @+1 {{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
// -----
func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
- // expected-note @+1 {{prior use here}}
- %0 = bufferization.to_tensor %m : memref<*xf32>
- // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+ // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
// Some op may write into the result of to_memref later.
// CHECK: bufferization.to_memref
// CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
- %m = bufferization.to_memref %t : memref<5xf32>
+ %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
%2 = tensor.extract %t[%idx] : tensor<5xf32>
return %2 : f32
}
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
// Some op may write into the result of to_memref later.
// CHECK: bufferization.to_memref
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
- %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+ %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
%2 = tensor.extract %t[%idx] : tensor<5xf32>
return %2 : f32
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+ %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+ %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+ return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+// CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/91524
More information about the Mlir-commits mailing list