[Mlir-commits] [mlir] [mlir][Interface] `DestinationStyleOpInterface`: Rename `hasTensor/BufferSemantics` (PR #77574)

Matthias Springer llvmlistbot at llvm.org
Wed Jan 10 02:01:12 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/77574

Rename interface functions as follows:
* `hasTensorSemantics` -> `hasPureTensorSemantics`
* `hasBufferSemantics` -> `hasPureBufferSemantics`

These two functions return "true" only for ops with pure semantics: `hasPureTensorSemantics` requires at least one tensor operand and no buffer operands; `hasPureBufferSemantics` requires at least one buffer operand and no tensor operands.

Add two new interface functions (a short C++ sketch of the resulting predicate set follows this list):
* `hasTensorSemantics`: Return "true" if the op has tensor operands. Whether the op has buffer operands or not does not matter.
* `hasBufferSemantics`: Return "true" if the op has buffer operands. Whether the op has tensor operands or not does not matter.
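
The four predicates can be summarized with the following minimal C++ sketch. This is not part of the patch; the actual implementations live in the interface's `extraClassDeclaration` (see the TableGen hunk below), and the free-function form here is purely illustrative:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Note: TensorType and BaseMemRefType match both the ranked and the
// unranked variants of the respective type.
static bool hasTensorSemantics(Operation *op) {
  return llvm::any_of(op->getOperands(), [](Value v) {
    return llvm::isa<TensorType>(v.getType());
  });
}

static bool hasBufferSemantics(Operation *op) {
  return llvm::any_of(op->getOperands(), [](Value v) {
    return llvm::isa<BaseMemRefType>(v.getType());
  });
}

// The "pure" variants additionally require the absence of the other kind.
static bool hasPureTensorSemantics(Operation *op) {
  return hasTensorSemantics(op) && !hasBufferSemantics(op);
}

static bool hasPureBufferSemantics(Operation *op) {
  return hasBufferSemantics(op) && !hasTensorSemantics(op);
}

// An op with one tensor operand and one memref operand satisfies both
// hasTensorSemantics and hasBufferSemantics, but neither "pure" variant.
// An op with only scalar operands satisfies none of the four.
```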

Also drop the "ranked" requirement from the interface; i.e., ranked and unranked types are no longer distinguished.

This change aligns the meaning of "tensor semantics" with the bufferization framework: an op is supposed to be bufferized if it has tensor operands, regardless of whether it also has memref operands. This change is in preparation for #75273, which adds `BufferizableOpInterface::hasTensorSemantics`.
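
For instance, the Linalg bufferization entry point updated in the `BufferizableOpInterfaceImpl.cpp` hunk below now boils down to the following shape. This is a simplified sketch; the helper name and the elided rewriting steps are illustrative, not the actual implementation:

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

using namespace mlir;

// Simplified sketch of the post-change bufferization flow.
static LogicalResult bufferizeDpsOp(RewriterBase &rewriter,
                                    DestinationStyleOpInterface op) {
  // Nothing to do: the op has buffer operands and no tensor operands,
  // i.e., it is already bufferized.
  if (op.hasPureBufferSemantics())
    return success();

  // Mixed tensor/buffer operands are still rejected at this point.
  if (!op.hasPureTensorSemantics())
    return op->emitError() << "op does not have pure tensor semantics";

  // ... replace each tensor operand with its buffer equivalent and
  // rewrite the op (omitted here) ...
  return success();
}
```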

>From ce9ebfdec5ad23e7a5ec1b28e0cde994f8bdef3c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 10 Jan 2024 10:00:01 +0000
Subject: [PATCH] [mlir][Interface] `DestinationStyleOpInterface`: Rename
 `hasTensor/BufferSemantics`

Rename interface functions as follows:
* `hasTensorSemantics` -> `hasPureTensorSemantics`
* `hasBufferSemantics` -> `hasPureBufferSemantics`

These two functions return "true" only for ops with pure semantics: `hasPureTensorSemantics` requires at least one tensor operand and no buffer operands; `hasPureBufferSemantics` requires at least one buffer operand and no tensor operands.

Add two new interface functions:
* `hasTensorSemantics`: Return "true" if the op has tensor operands. Whether the op has buffer operands or not does not matter.
* `hasBufferSemantics`: Return "true" if the op has buffer operands. Whether the op has tensor operands or not does not matter.

Also drop the "ranked" requirement from the interface; i.e., ranked and unranked types are no longer distinguished.

This change aligns the meaning of "tensor semantics" with the bufferization framework: an op is supposed to be bufferized if it has tensor operands, regardless of whether it also has memref operands. This change is in preparation for #75273, which adds `BufferizableOpInterface::hasTensorSemantics`.
---
 .../Interfaces/DestinationStyleOpInterface.td | 75 +++++++++----------
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 10 +--
 .../Transforms/BubbleUpExtractSlice.cpp       |  2 +-
 .../BufferizableOpInterfaceImpl.cpp           |  6 +-
 .../Linalg/Transforms/ConstantFold.cpp        |  2 +-
 .../Linalg/Transforms/DecomposeLinalgOps.cpp  |  2 +-
 .../Linalg/Transforms/DropUnitDims.cpp        |  2 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 16 ++--
 .../EraseUnusedOperandsAndResults.cpp         |  4 +-
 .../Linalg/Transforms/Generalization.cpp      |  2 +-
 .../Transforms/InlineScalarOperands.cpp       |  2 +-
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp  |  6 +-
 .../Linalg/Transforms/NamedOpConversions.cpp  |  2 +-
 .../lib/Dialect/Linalg/Transforms/Padding.cpp |  4 +-
 .../Dialect/Linalg/Transforms/Promotion.cpp   |  8 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  4 +-
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  6 +-
 .../Linalg/Transforms/Vectorization.cpp       |  2 +-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 10 +--
 .../NVGPU/Transforms/CreateAsyncGroups.cpp    |  2 +-
 .../Transforms/SparseReinterpretMap.cpp       |  4 +-
 .../Transforms/SparseTensorRewriting.cpp      | 10 +--
 .../Transforms/Sparsification.cpp             |  2 +-
 .../DestinationStyleOpInterface.cpp           |  8 +-
 .../mlir-linalg-ods-yaml-gen.cpp              |  2 +-
 25 files changed, 96 insertions(+), 97 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
index 4c52d803e11476..b1ea4c82a08c82 100644
--- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
@@ -17,24 +17,26 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
     as initial tensor values for the results of the operation or the init
     buffers to which the results of the op will be written.
 
-    Init operands must be ranked tensors or ranked memrefs. Input operands can
-    have any type. All non-init operands are DPS inputs.
+    Init operands must be tensors or memrefs. Input operands can have any type.
+    All non-init operands are DPS inputs.
 
     The init operands of this op are specified by the MutableOperandRange that
     the `getDpsInitsMutable` interface methods returns. This implies that the
     init operands must be a consecutive range of operands.
 
-    If the op has "tensor semantics", then the input operands are either ranked
-    tensors or other non-tensor/memref types ("scalars"). The init operands are
-    ranked tensors and every tensor init is tied to a corresponding tensor
-    OpResult in a 1-to-1 fashion. The i-th init tensor is tied to the i-th
-    OpResult. The op may not have any additional OpResults. Init operands and
-    their tied OpResults have the same type. Dynamic dimension sizes also match
-    at runtime.
+    Each tensor init operand is tied to a corresponding tensor OpResult in a
+    1-to-1 fashion. The i-th init tensor is tied to the i-th OpResult. The op
+    may not have any additional OpResults. Init operands and their tied
+    OpResults have the same type. Dynamic dimension sizes also match at runtime.
 
-    If the op has "buffer semantics", then the input operands are either ranked
-    memrefs or other non-tensor/memref types ("scalar" types). Furthermore, the
-    init operands are ranked memrefs and the op has no results.
+    Note: This implies that a destination style op without any tensor inits must
+    not have any OpResults.
+
+    An op has "tensor semantics" if it has at least one tensor operand.
+    An op has "buffer semantics" if it has at least one buffer (memref) operand.
+    An op has "pure tensor semantics" if it has tensor semantics but not buffer
+    semantics. An op has "pure buffer semantics" if it has buffer semantics but
+    not tensor semantics.
 
     Destination-passing style abstraction makes certain transformations easier.
     For example, tiling implementation can extract/insert slices from/into the
@@ -148,7 +150,8 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
     /// neither a MemRef nor a tensor value.
     bool isScalar(::mlir::OpOperand *opOperand) {
       assert(opOperand->getOwner() == $_op && "invalid operand");
-      return !::llvm::isa<MemRefType, TensorType>(opOperand->get().getType());
+      return !::llvm::isa<BaseMemRefType, TensorType>(
+          opOperand->get().getType());
     }
 
     /// Return the OpResult that is tied to the given OpOperand.
@@ -169,36 +172,30 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
       return $_op.getDpsInitOperand(opResult.getResultNumber());
     }
 
-    /// Return whether the op has buffer semantics. That is the case if the op
-    /// has no ranked tensor operands and at least one memref operand.
+    /// Return whether the op has buffer semantics. That is the case if the
+    /// op has at least one memref operand.
     bool hasBufferSemantics() {
-      // No tensors.
-      auto isTensor = [](Value v){
-        return ::llvm::isa<::mlir::RankedTensorType>(v.getType());
-      };
-      if (::llvm::any_of($_op->getOperands(), isTensor))
-        return false;
-      // At least one memref.
-      auto isMemref = [](Value v){
-        return ::llvm::isa<::mlir::MemRefType>(v.getType());
-      };
-      return llvm::any_of($_op->getOperands(), isMemref);
+      return ::llvm::any_of($_op->getOperands(),
+          [](Value v) { return isa<BaseMemRefType>(v.getType()); });
     }
 
-    /// Return whether the op has tensor semantics. That is the case if the op
-    /// has no memref operands and at least one ranked tensor operand.
+    /// Return whether the op has tensor semantics. That is the case if the
+    /// op has at least one tensor operand.
     bool hasTensorSemantics() {
-      // No memrefs.
-      auto isMemref = [](Value v){
-        return ::llvm::isa<::mlir::MemRefType>(v.getType());
-      };
-      if (::llvm::any_of($_op->getOperands(), isMemref))
-        return false;
-      // At least one tensor.
-      auto isTensor = [](Value v){
-        return ::llvm::isa<::mlir::RankedTensorType>(v.getType());
-      };
-      return llvm::any_of($_op->getOperands(), isTensor);
+      return ::llvm::any_of($_op->getOperands(),
+          [](Value v) { return isa<TensorType>(v.getType()); });
+    }
+
+    /// Return whether the op has pure buffer semantics. That is the case if the
+    /// op has no tensor operands and at least one memref operand.
+    bool hasPureBufferSemantics() {
+      return hasBufferSemantics() && !hasTensorSemantics();
+    }
+
+    /// Return whether the op has pure tensor semantics. That is the case if the
+    /// op has no memref operands and at least one tensor operand.
+    bool hasPureTensorSemantics() {
+      return hasTensorSemantics() && !hasBufferSemantics();
     }
   }];
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b68aa77fd83a1c..828a140be75456 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -550,7 +550,7 @@ struct EraseSelfCopy : OpRewritePattern<CopyOp> {
                                 PatternRewriter &rewriter) const override {
     if (copyOp.getInputs() != copyOp.getOutputs())
       return rewriter.notifyMatchFailure(copyOp, "not a self copy");
-    if (copyOp.hasBufferSemantics())
+    if (copyOp.hasPureBufferSemantics())
       rewriter.eraseOp(copyOp);
     else
       rewriter.replaceOp(copyOp, copyOp.getInputs());
@@ -1112,7 +1112,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
       return failure();
 
     // In the buffer case, we need to check exact buffer equality.
-    if (genericOp.hasBufferSemantics()) {
+    if (genericOp.hasPureBufferSemantics()) {
       if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
           genericOp.getDpsInputOperand(0)->get() ==
               genericOp.getDpsInitOperand(0)->get()) {
@@ -1123,7 +1123,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     }
 
     // Mixed semantics is not supported yet.
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
 
     // Get the argument number of the returned values. That is the operand
@@ -2257,7 +2257,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
 
   LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    if (!linalgOp.hasTensorSemantics())
+    if (!linalgOp.hasPureTensorSemantics())
       return failure();
 
     // Maps must be projected permutations.
@@ -2376,7 +2376,7 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
       getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
 
   SmallVector<Type, 4> resultTypes;
-  if (hasTensorSemantics())
+  if (hasPureTensorSemantics())
     resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 5c4bc9137c10a8..428422e6e875a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -68,7 +68,7 @@ struct BubbleUpExtractSliceOpPattern
                                          "expected single output of linalg op");
     }
 
-    if (!linalgOp.hasTensorSemantics()) {
+    if (!linalgOp.hasPureTensorSemantics()) {
       return rewriter.notifyMatchFailure(sliceOp,
                                          "expected tensor of linalg op");
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 0577441bdd28d2..b232d56d4419f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -32,13 +32,13 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
   rewriter.setInsertionPoint(op);
 
   // Nothing to do. This op is already bufferized.
-  if (op.hasBufferSemantics())
+  if (op.hasPureBufferSemantics())
     return success();
 
   // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
   // basis.
-  if (!op.hasTensorSemantics())
-    return op->emitError() << "op does not have tensor semantics";
+  if (!op.hasPureTensorSemantics())
+    return op->emitError() << "op does not have pure tensor semantics";
 
   // New input operands for the cloned op.
   SmallVector<Value> newInputBuffers;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 062751552b3cc6..8fffabf11f3fdd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -57,7 +57,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Mixed and buffer sematics aren't supported.
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
 
     // Only support ops generating one output for now.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 28f4d8ac64431a..5cd6d4597affaf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -258,7 +258,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
   // TODO: this could be generalized to handle `linalg.generic` with buffer
   // operands too but requires allocation for intermediates. Punt on this for
   // now.
-  if (!genericOp.hasTensorSemantics()) {
+  if (!genericOp.hasPureTensorSemantics()) {
     return rewriter.notifyMatchFailure(
         genericOp, "only operations with tensor semantics are handled");
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c495956fa57702..e6f4ed5b51b1e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -83,7 +83,7 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
     if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3eb91190751ef1..031f5c7a5d4783 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -105,7 +105,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   // Consumer can have mixed semantics, just check operand itself has tensor
   // type. Producer must have full tensor semantics to avoid potential
   // aliasing between producer and consumer memrefs.
-  if (!producer.hasTensorSemantics() ||
+  if (!producer.hasPureTensorSemantics() ||
       !isa<RankedTensorType>(fusedOperand->get().getType()))
     return false;
 
@@ -530,7 +530,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
   //   permutations.
   // - The fused tensor is not a scalar.
   // - All the loops are parallel loops.
-  return genericOp.hasTensorSemantics() &&
+  return genericOp.hasPureTensorSemantics() &&
          llvm::all_of(genericOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
                         return cast<AffineMapAttr>(attr)
@@ -1124,7 +1124,7 @@ static SmallVector<ReassociationIndices>
 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                  ArrayRef<ReassociationIndices> reassociation) {
   // Some basic checks for this fusion to be valid.
-  if (!genericOp.hasTensorSemantics() || genericOp.getNumDpsInits() != 1)
+  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
     return {};
 
   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
@@ -1476,7 +1476,7 @@ Operation *createCollapsedOp(LinalgType op,
     outputOperands.push_back(newOutput);
     // If the op has "buffer semantics", then the init operands are ranked
     // memrefs and the op has no results.
-    if (!op.hasBufferSemantics())
+    if (!op.hasPureBufferSemantics())
       resultTypes.push_back(newOutput.getType());
   }
 
@@ -1521,8 +1521,8 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
       }))
     return failure();
 
-  bool hasBufferSemantics = op.hasBufferSemantics();
-  if (hasBufferSemantics &&
+  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
+  if (hasPureBufferSemantics &&
       !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
         MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
         if (!memRefToCollapse)
@@ -1705,7 +1705,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
     for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
       Operation *def = opOperand->get().getDefiningOp();
@@ -1857,7 +1857,7 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
     bool fillFound = false;
     Block &payload = genericOp.getRegion().front();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index 4e54e48c914aeb..3378eda2bd6734 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -183,7 +183,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
         dedupedOutpts;
     // If the op doesn't have tensor semantics or outputs should not be removed,
     // keep all the outputs as preserved.
-    if (!genericOp.hasTensorSemantics() || !removeOutputs) {
+    if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
       for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
         origToNewPos[en.index()] = newOutputOperands.size();
         newOutputOperands.push_back(en.value().get());
@@ -317,7 +317,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
                                 PatternRewriter &rewriter) const override {
 
     // If the op doesnt have tensor semantics, preserve the outputs as is.
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
 
     bool hasRemovedCycles = false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 1d9ce4144f998d..d03d1f3a163c32 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -59,7 +59,7 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
   ValueRange outputs = linalgOp.getDpsInits();
   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
   SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
-  SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
+  SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
                                       ? TypeRange(ValueRange(outputs))
                                       : TypeRange{};
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index cc39fe932c24bf..34db710b1721d6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -35,7 +35,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
+    if (!genericOp.hasPureTensorSemantics())
       return failure();
 
     SmallVector<size_t> scalarOperands;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 5a56e914ea4c77..4c93da6fe9253f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -128,7 +128,7 @@ template <typename LoadOpTy, typename StoreOpTy>
 static void emitScalarImplementation(OpBuilder &b, Location loc,
                                      ArrayRef<Value> allIvs,
                                      LinalgOp linalgOp) {
-  assert(linalgOp.hasBufferSemantics() &&
+  assert(linalgOp.hasPureBufferSemantics() &&
          "expected linalg op with buffer semantics");
   SmallVector<Value> indexedValues;
   indexedValues.reserve(linalgOp->getNumOperands());
@@ -218,7 +218,7 @@ static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
-  assert(linalgOp.hasBufferSemantics() &&
+  assert(linalgOp.hasPureBufferSemantics() &&
          "expected linalg op with buffer semantics");
 
   auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
@@ -264,7 +264,7 @@ class LinalgRewritePattern : public RewritePattern {
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     auto linalgOp = dyn_cast<LinalgOp>(op);
-    if (!isa<LinalgOp>(op) || !linalgOp.hasBufferSemantics()) {
+    if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
       return rewriter.notifyMatchFailure(
           op, "expected linalg op with buffer semantics");
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index 93fa5ff24ac6a6..250360603fa5dd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -39,7 +39,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
   Location loc = operation->getLoc();
   auto linalgOp = dyn_cast<LinalgOp>(operation);
   // Exit out on the memref version of this operation.
-  if (!linalgOp || !linalgOp.hasTensorSemantics())
+  if (!linalgOp || !linalgOp.hasPureTensorSemantics())
     return failure();
 
   auto result = operation->getResult(0);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index e6d80a39650ccf..278f3499f53e82 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -168,7 +168,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
   }
 
   // TODO: there are cases where we may still want to pad to larger sizes.
-  if (!opToPad.hasTensorSemantics())
+  if (!opToPad.hasPureTensorSemantics())
     return rewriter.notifyMatchFailure(opToPad,
                                        "expected operation on tensors");
 
@@ -265,7 +265,7 @@ mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
   assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
          "invalid options");
 
-  if (!linalgOp.hasTensorSemantics())
+  if (!linalgOp.hasPureTensorSemantics())
     return rewriter.notifyMatchFailure(
         linalgOp, "only applies to Linalg ops with tensor semantics");
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 7c8ee1727d56f8..9311fd0bb2a478 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -164,7 +164,8 @@ struct LinalgOpInstancePromotionOptions {
 LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
     LinalgOp linalgOp, const LinalgPromotionOptions &options)
     : subViews(), alignment(options.alignment) {
-  assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
+  assert(linalgOp.hasPureBufferSemantics() &&
+         "revisit usage of shaped operand");
   auto vUseFullTileBuffers =
       options.useFullTileBuffers.value_or(llvm::SmallBitVector());
   vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
@@ -346,7 +347,8 @@ promoteSubViews(ImplicitLocOpBuilder &b,
 static FailureOr<LinalgOp>
 promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+  assert(op.hasPureBufferSemantics() &&
+         "expected linalg op with buffer semantics");
 
   // 1. Promote the specified views and use them in the new op.
   auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
@@ -400,7 +402,7 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options) {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   // Transformation applies to buffers only.
-  if (!linalgOp || !linalgOp.hasBufferSemantics())
+  if (!linalgOp || !linalgOp.hasPureBufferSemantics())
     return failure();
   // Check that at least one of the requested operands is indeed a subview.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index cae7b50b0fb3b4..8b3119f02e8fda 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -212,7 +212,7 @@ struct LinalgOpTilingInterface
                                              Location loc,
                                              ValueRange ivs) const {
     auto linalgOp = cast<LinalgOp>(op);
-    if (!linalgOp.hasBufferSemantics())
+    if (!linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have buffer semantics");
 
     SmallVector<Value> indexedValues;
@@ -256,7 +256,7 @@ struct LinalgOpPartialReductionInterface
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);
 
-    if (linalgOp.hasBufferSemantics())
+    if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
     // Insert the new parallel dimension based on the index of the reduction
     // loops. This could be controlled by user for more flexibility.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9d230e2c2e5749..51be5c9d5c573e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1344,7 +1344,7 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
 template <typename Conv2DOp, typename Conv1DOp>
 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
     returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
-  if (convOp.hasBufferSemantics())
+  if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
   Value input = convOp.getInputs().front();
@@ -1468,7 +1468,7 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
 FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
     DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
-  if (convOp.hasBufferSemantics())
+  if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
   Value input = convOp.getInputs().front();
@@ -1536,7 +1536,7 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 FailureOr<Conv1DOp>
 DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                             PatternRewriter &rewriter) const {
-  if (convOp.hasBufferSemantics())
+  if (convOp.hasPureBufferSemantics())
     return failure(); // To be implemented.
 
   Value input = convOp.getInputs().front();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5d99951ef09a92..dc348ea827cde1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -280,7 +280,7 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                          operandDimPos)))
       return failure();
 
-    Value dynamicDim = linalgOp.hasTensorSemantics()
+    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                            ? (Value)rewriter.create<tensor::DimOp>(
                                  linalgOp.getLoc(), operand, operandDimPos)
                            : (Value)rewriter.create<memref::DimOp>(
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 75c8cd3e1d95a1..986b5f3e1fb604 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -322,7 +322,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
          "expected as many entries for proc info as number of loops, even if "
          "they are null entries");
   SmallVector<Value> iterArgInitValues;
-  if (!linalgOp.hasBufferSemantics())
+  if (!linalgOp.hasPureBufferSemantics())
     llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
@@ -362,7 +362,7 @@ void GenerateLoopNest<AffineForOp>::doit(
         bodyBuilderFn,
     ArrayRef<linalg::ProcInfo> /*procInfo*/) {
   SmallVector<Value> iterArgInitValues;
-  if (!linalgOp.hasBufferSemantics())
+  if (!linalgOp.hasPureBufferSemantics())
     llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
   SmallVector<Value, 4> lbs, ubs, steps;
@@ -529,7 +529,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
         bodyBuilderFn,
     ArrayRef<linalg::ProcInfo> procInfo) {
   SmallVector<Value> iterArgInitValues;
-  if (!linalgOp.hasBufferSemantics())
+  if (!linalgOp.hasPureBufferSemantics())
     llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
   // This function may be passed more iterator types than ranges.
@@ -738,7 +738,7 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
 }
 
 SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
-  if (op.hasBufferSemantics())
+  if (op.hasPureBufferSemantics())
     return {};
   return llvm::to_vector(
       llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
@@ -749,7 +749,7 @@ SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
 SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                     LinalgOp op, ValueRange operands,
                                     ValueRange results) {
-  if (op.hasBufferSemantics())
+  if (op.hasPureBufferSemantics())
     return {};
   SmallVector<Value> tensorResults;
   tensorResults.reserve(results.size());
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index f63825cdc8f617..f8c699c65fe49e 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -23,7 +23,7 @@ using namespace mlir;
 template <typename OpTy>
 static bool isContiguousXferOp(OpTy op) {
   return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
-         op.hasBufferSemantics() &&
+         op.hasPureBufferSemantics() &&
          isLastMemrefDimUnitStride(
              cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index f2e1b0bc58f132..50713be8296fa8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -373,7 +373,7 @@ struct GenericOpReinterpretMap
                           PatternRewriter &rewriter) const {
     // Only rewrite single output operations with pure (sparse) tensor
     // semantics.
-    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
         !hasAnySparseOperandOrResult(linalgOp) ||
         !hasAnyNonIdentityOperandsOrResults(linalgOp))
       return failure();
@@ -411,7 +411,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
         hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
         !hasAnySparseOperandOrResult(linalgOp)) {
       return failure();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 3b9685b8ae1e07..fa97e405584791 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -214,7 +214,7 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
+    if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
         !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
       return failure();
@@ -257,7 +257,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
     // Check consumer.
-    if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 2 ||
+    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
         op.getNumResults() != 1 ||
         op.getNumParallelLoops() != op.getNumLoops() ||
         !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
@@ -276,7 +276,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     // Check producer.
     auto prod = dyn_cast_or_null<GenericOp>(
         op.getDpsInputOperand(other)->get().getDefiningOp());
-    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
+    if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
         !prod.getResult(0).hasOneUse())
       return failure();
     // Sampling consumer and sum of multiplication chain producer.
@@ -407,7 +407,7 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
     // Rejects non sparse kernels.
-    if (!op.hasTensorSemantics() || !hasAnySparseOperand(op))
+    if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
       return failure();
 
     Location loc = op.getLoc();
@@ -540,7 +540,7 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
     // Reject non-reductions.
-    if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 1 ||
+    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
         op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
       return failure();
     auto inp = op.getDpsInputOperand(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 35eb4b4f6e47f8..5834426cae2f41 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1297,7 +1297,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
     // Only accept single output operations with pure tensor semantics.
-    if (op.getNumDpsInits() != 1 || !op.hasTensorSemantics())
+    if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
       return failure();
 
     // Only accept trivial affine indices.
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
index 4e5ef66887cadf..496238fcaa3ff1 100644
--- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -33,12 +33,11 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
   SmallVector<OpOperand *> outputTensorOperands;
   for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
     Type type = operand.get().getType();
-    if (isa<RankedTensorType>(type)) {
+    if (isa<TensorType>(type)) {
       outputTensorOperands.push_back(&operand);
-    } else if (!isa<MemRefType>(type)) {
+    } else if (!isa<BaseMemRefType>(type)) {
       return op->emitOpError("expected that operand #")
-             << operand.getOperandNumber()
-             << " is a ranked tensor or a ranked memref";
+             << operand.getOperandNumber() << " is a tensor or a memref";
     }
   }
 
@@ -58,5 +57,6 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
              << " to match type of corresponding result (" << result.getType()
              << ")";
   }
+
   return success();
 }
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fb3c9d48f9a982..f14e559fff92f3 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -658,7 +658,7 @@ LogicalResult {0}::fold(FoldAdaptor,
 }
 void {0}::getEffects(SmallVectorImpl<
     SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
-      if (hasTensorSemantics()) return;
+      if (hasPureTensorSemantics()) return;
       getGenericEffectsImpl(effects,
         getOperation()->getResults(), getDpsInputs(), getDpsInits());
 }


