[Mlir-commits] [mlir] [mlir][bufferization] Use Type instead of Value in unknown conversion (PR #144658)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 18 02:21:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Andrei Golubev (andrey-golubev)

<details>
<summary>Changes</summary>

Generally, bufferization should be able to create a memref from a tensor without needing to know more than just an mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of an mlir::Value. Additionally, apply the same rationale to the getMemRefType() helper function.

Both changes are prerequisites for enabling custom type support in one-shot bufferization.

---
Full diff: https://github.com/llvm/llvm-project/pull/144658.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+5-4) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-10) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..2fb795f16ae2c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -265,9 +265,9 @@ struct BufferizationOptions {
       std::function<BaseMemRefType(TensorType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
   /// Tensor -> MemRef type converter.
-  /// Parameters: Value, memory space, bufferization options
+  /// Parameters: tensor type, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
-      Value, Attribute memorySpace, const BufferizationOptions &)>;
+      TensorType, Attribute memorySpace, const BufferizationOptions &)>;
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
       std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
   return newOp;
 }
 
-/// Return a MemRefType to which the type of the given value can be bufferized.
+/// Return a MemRefType to which the TensorType can be bufferized.
 ///
 /// If possible, op bufferization implementations should not use this function
 /// and instead infer precise memref types for tensor results by themselves.
@@ -667,7 +667,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
 /// Note: Canonicalization patterns could clean up layout maps and infer more
 /// precise layout maps after bufferization. However, many possible
 /// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
+BaseMemRefType getMemRefType(TensorType tensorType,
+                             const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},
                              Attribute memorySpace = nullptr);
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..dd43647682ea2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
 BaseMemRefType
-defaultUnknownTypeConverter(Value value, Attribute memorySpace,
+defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
                             const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(
-      llvm::cast<TensorType>(value.getType()), memorySpace);
+  return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
 }
 
 } // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  return getMemRefType(cast<TensorType>(value.getType()), options,
+                       /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
 // Bufferization-specific IRMapping support with debugging.
 //===----------------------------------------------------------------------===//
 
-BaseMemRefType bufferization::getMemRefType(Value value,
+BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
                                             Attribute memorySpace) {
-  auto tensorType = llvm::cast<TensorType>(value.getType());
-
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType =
           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
                            memorySpace);
   }
 
-  return options.unknownTypeConverterFn(value, memorySpace, options);
+  return options.unknownTypeConverterFn(tensorType, memorySpace, options);
 }
 
 BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
+  auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
   if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(value, options);
+    return bufferization::getMemRefType(tensorType, options);
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7681d309a4af..7e9b9119ce949 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
                   "'unknown-type-conversion'");
         return signalPassFailure();
       }
-      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
+      opt.unknownTypeConverterFn = [=](TensorType tensorType,
+                                       Attribute memorySpace,
                                        const BufferizationOptions &options) {
-        auto tensorType = cast<TensorType>(value.getType());
         if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
           return bufferization::getMemRefTypeWithStaticIdentityLayout(
               tensorType, memorySpace);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index a3ab53d818115..15e5102462ad7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
   OneShotBufferizationOptions options;
   options.bufferizeFunctionBoundaries = true;
   options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
-  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
+  options.unknownTypeConverterFn = [](TensorType tensorType,
+                                      Attribute memorySpace,
                                       const BufferizationOptions &options) {
-    return getMemRefTypeWithStaticIdentityLayout(
-        cast<TensorType>(value.getType()), memorySpace);
+    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
   };
   if (analysisOnly) {
     options.testAnalysisOnly = true;

``````````

</details>


https://github.com/llvm/llvm-project/pull/144658


More information about the Mlir-commits mailing list