[Mlir-commits] [mlir] [mlir][bufferization] Fix OneShotBufferize when `defaultMemorySpaceFn` is used (PR #91524)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 8 12:31:56 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-func

Author: Christopher Bate (christopherbate)

<details>
<summary>Changes</summary>

As described in issue llvm/llvm-project#91518, an earlier PR,
llvm/llvm-project#78484, introduced `defaultMemorySpaceFn` into the bufferization
options. It lets one tell OneShotBufferize to use a specified function to derive
the memory space attribute from the encoding attribute attached to tensor types.
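
To make the shape of that callback concrete, here is a minimal standalone sketch that does not depend on MLIR. `ToyAttr`, `ToyTensorType`, `MemorySpaceFn`, and `makeEncodingMemorySpaceFn` are hypothetical stand-ins invented for illustration; only the control flow mirrors the `useEncodingForMemorySpace` branch that this patch adds to `Bufferize.cpp`.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>

// Hypothetical stand-in for mlir::Attribute; an empty string plays the role
// of a null attribute.
using ToyAttr = std::string;

// Hypothetical stand-in for mlir::TensorType.
struct ToyTensorType {
  bool ranked;      // true models RankedTensorType
  ToyAttr encoding; // only meaningful when `ranked` is true
};

// Same shape as the callback in the patch: tensor type -> optional attribute.
using MemorySpaceFn =
    std::function<std::optional<ToyAttr>(const ToyTensorType &)>;

// Mirrors the new branch: a ranked tensor forwards its encoding (even a null
// one), and an unranked tensor yields no memory space at all.
inline MemorySpaceFn makeEncodingMemorySpaceFn() {
  return [](const ToyTensorType &t) -> std::optional<ToyAttr> {
    if (t.ranked)
      return t.encoding;
    return std::nullopt;
  };
}

// Small usage example of the toy callback.
inline void demoEncodingMemorySpaceFn() {
  MemorySpaceFn fn = makeEncodingMemorySpaceFn();
  assert(fn(ToyTensorType{true, "1 : i64"}).value() == "1 : i64");
  assert(!fn(ToyTensorType{false, ""}).has_value());
}
```

In this toy model, as in the patch, a ranked tensor without an encoding still produces a value (the null attribute) rather than `std::nullopt`.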

However, introducing this feature exposed some unhandled edge cases. Examples
are added by this change in the new test under
`test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir`.

Fixing the inconsistencies introduced by `defaultMemorySpaceFn` is pretty
simple. This change:

- updates the `bufferization.to_memref` and `bufferization.to_tensor` operations
  to explicitly include both operand and result types in their assembly format.
  Previously only the memref type was written and the tensor type was inferred
  from it. Since that inference cannot recover the correct tensor
  encoding/memory space, both types must now be spelled out.
- makes minor updates to other bufferization functions to handle the
  changes in building the above ops
- updates bufferization of `tensor.from_elements` to handle memory spaces
  instead of emitting a "memory space not implemented yet" error
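
The relaxed verifier check behind the first bullet can be sketched without MLIR. This is a toy model of the `tensorTypesMatchUpToEncoding` helper added in `BufferizationOps.cpp`: `ToyShapedType` and `toyTypesMatchUpToEncoding` are hypothetical names used only for illustration, and an unset `shape` is assumed here to model an unranked type.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::ShapedType.
struct ToyShapedType {
  std::string elementType;                   // e.g. "f32"
  std::optional<std::vector<int64_t>> shape; // std::nullopt models "unranked"
  std::string encoding;                      // deliberately ignored below
};

// Follows the helper in the patch: element types must agree, shapes must
// agree when both sides are ranked, and the encoding is never compared.
inline bool toyTypesMatchUpToEncoding(const ToyShapedType &lhs,
                                      const ToyShapedType &rhs) {
  if (lhs.elementType != rhs.elementType)
    return false;
  if (lhs.shape && rhs.shape)
    return *lhs.shape == *rhs.shape;
  return true;
}

// Small usage example: same shape and element type, different encodings.
inline void demoTypesMatchUpToEncoding() {
  ToyShapedType withEncoding{"f32", std::vector<int64_t>{3}, "1 : i64"};
  ToyShapedType plain{"f32", std::vector<int64_t>{3}, ""};
  assert(toyTypesMatchUpToEncoding(withEncoding, plain));
}
```

Note that, as in the patch, a ranked type compared against an unranked one passes as long as the element types agree.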


---

Patch is 226.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91524.diff


68 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h (+6) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+12-6) 
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+4) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+10) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+12-1) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+12-10) 
- (modified) mlir/test/Dialect/Arith/bufferize.mlir (+3-3) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir (+1-1) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir (+15-16) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir (+2-2) 
- (added) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir (+133) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir (+3-3) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+3-3) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir (+1-1) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir (+2-2) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+16-16) 
- (modified) mlir/test/Dialect/Bufferization/ops.mlir (+3-3) 
- (modified) mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir (+2-2) 
- (modified) mlir/test/Dialect/Func/func-bufferize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/bufferize.mlir (+5-5) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+1-1) 
- (modified) mlir/test/Dialect/MemRef/normalize-memrefs.mlir (+2-2) 
- (modified) mlir/test/Dialect/SCF/bufferize.mlir (+6-6) 
- (modified) mlir/test/Dialect/Shape/bufferize.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_sddmm_lib.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/constant_index_map.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/dense.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sorted_coo.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+14-14) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+39-39) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_3d.mlir (+41-41) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+8-8) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_batch.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+11-11) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir (+17-17) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels.mlir (+9-9) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower.mlir (+4-4) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir (+4-4) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+4-4) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_nd.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+6-6) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_scalars.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+5-5) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/unused-tensor.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+14-14) 
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+17-17) 
- (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Vector/bufferize.mlir (+3-3) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-full.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 6f19dca2e8222..d6ccbdd7acc1f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,10 +12,16 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir::bufferization::detail {
+bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
+} // namespace mlir::bufferization::detail
 
 //===----------------------------------------------------------------------===//
 // Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..7be44d566d481 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -388,9 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    TypesMatchWith<"result type matches tensor equivalent of 'memref'",
-                   "memref", "result",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+    AllElementTypesMatch<["memref", "result"]>
   ]> {
   let summary = "create a tensor from a `memref`";
   let description = [{
@@ -477,9 +475,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
   let assemblyFormat = [{
     $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref)
+      `:` type($memref) `->` type($result)
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+      build($_builder, $_state, rtt, memref, restrict, writeable);
+    }]>
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -496,7 +501,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     Pure,
     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
                    "memref", "tensor",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+                   "memref::getTensorTypeFromMemRefType($_self)",
+                   "bufferization::detail::tensorTypesMatchUpToEncoding">
   ]> {
   let summary = "cast a tensor to memref";
   let description = [{
@@ -551,7 +557,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 75ce85c9128c9..656edbfb3deaa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -520,6 +520,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            /*default=*/"false",
            "The memory space of an memref types must always be inferred. If "
            "unset, a default memory space of 0 is used otherwise.">,
+    Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+            /*default=*/"false",
+            "Use the Tensor encoding attribute for the memory space. Exclusive to"
+            " the 'must-infer-memory-space option'">,
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
             /*default=*/"false",
            "Test only: Only run inplaceability analysis and annotate IR">,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f243ea0..550ac7e83b9e0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
       // loose all of its users and eventually DCE away.
       rewriter.setInsertionPointAfter(op);
       replacement = rewriter.create<bufferization::ToTensorOp>(
-          replacement.getLoc(), replacement);
+          replacement.getLoc(), opResult.getType(), replacement);
     }
     replacements.push_back(replacement);
   }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..bfb742e5e0176 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
 // Helper functions
 //===----------------------------------------------------------------------===//
 
+bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
+  auto lhsType = cast<ShapedType>(lhs);
+  auto rhsType = cast<ShapedType>(rhs);
+  if (lhsType.getElementType() != rhsType.getElementType())
+    return false;
+  if (lhsType.hasRank() && rhsType.hasRank())
+    return lhsType.getShape() == rhsType.getShape();
+  return true;
+}
+
 FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     OpBuilder &b, Value value, MemRefType destType,
     const BufferizationOptions &options) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7ba347a1f15e4..b43041d629dd3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
+    // Unranked to ranked casts must be explicit.
+    if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
+      return nullptr;
+
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
       // MemRef to MemRef cast.
       assert(inputType != type && "expected different types");
-      // Unranked to ranked and ranked to unranked casts must be explicit.
+      // Ranked to unranked casts must be explicit.
       auto rankedDestType = dyn_cast<MemRefType>(type);
       if (!rankedDestType)
         return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
             [](TensorType t) -> std::optional<Attribute> {
           return std::nullopt;
         };
+      } else if (useEncodingForMemorySpace) {
+        opt.defaultMemorySpaceFn =
+            [](TensorType t) -> std::optional<Attribute> {
+          if (auto rtt = dyn_cast<RankedTensorType>(t))
+            return rtt.getEncoding();
+          return std::nullopt;
+        };
       }
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..a46f500b76c3f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
-        mixedSizes, mixedStrides);
+        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+        mixedOffsets, mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
-        extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
-        mixedOffsets, mixedSizes, mixedStrides));
+        extractSliceOp.getType().getShape(),
+        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+        mixedStrides));
   }
 };
 
@@ -478,9 +479,8 @@ struct FromElementsOpInterface
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
+    std::optional<Attribute> memorySpace =
+        options.defaultMemorySpaceFn(tensorType);
 
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
@@ -491,10 +491,12 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(memrefType))
+      return failure();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
+        op->getLoc(), *memrefType, *tensorAlloc);
 
     // Case: tensor<0xelem_type>.
     if (fromElementsOp.getElements().empty()) {
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index 944954e9e4edd..31b4577cdd62f 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -8,7 +8,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
   %index_scalar = arith.index_cast %scalar : i32 to index
   return %index_tensor, %index_scalar : tensor<index>, index
 }
-// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
 // CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
 // CHECK-SAME:   memref<i32> to memref<index>
 // CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -87,8 +87,8 @@ func.func @non_tensor() {
 // CHECK-SAME:                 %[[PRED:.*]]: i1,
 // CHECK-SAME:                 %[[TRUE_VAL:.*]]: tensor<f32>,
 // CHECK-SAME:                 %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
 // CHECK:           %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
 // CHECK:           return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
index 5293977fe733f..55e086ff0110f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
@@ -9,7 +9,7 @@
 //  CHECK-NEXT:   %[[clone:.*]] = bufferization.clone %[[m]]
 //  CHECK-NEXT:   return %[[clone]]
 func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
-  %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+  %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
   return %0 : memref<?x?x?xf16>
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d92..500bdb4f9afc5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -4,8 +4,8 @@
 // CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<f32> {
 // CHECK:           return %[[ARG]] : memref<f32>
 func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
 func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
   // expected-error @+1 {{failed to legalize operation 'test.source'}}
   %0 = "test.source"() : () -> tensor<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
 // -----
 
 func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
   // expected-error @+1 {{failed to legalize operation 'test.sink'}}
   "test.sink"(%0) : (tensor<f32>) -> ()
   return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
 // TODO: to_memref with layout maps not supported yet. This should fold to a
 // memref.cast.
 func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
-  %0 = bufferization.to_tensor %m : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
   // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
-  %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
   // expected-note @+1 {{see existing live user here}}
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
 // -----
 
 func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
-  // expected-note @+1 {{prior use here}}
-  %0 = bufferization.to_tensor %m : memref<*xf32>
-  // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+  // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  %m = bufferization.to_memref %t : memref<5xf32>
+  %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+  %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+  %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+  %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+  return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+//  CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+//       CHECK:     %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+//       CHECK-DAG:     %[[c0:.+]] = arith.constant 0 : index
+//       CHECK-DAG:     %[[c1:.+]] = arith.constant 1 : index
+//       CHECK-DAG:     %[[c2:.+]] = arith.constant 2 : index
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/91524

