[Mlir-commits] [mlir] [mlir][side effect] refactor(*): Include more precise side effects (PR #94213)

donald chen llvmlistbot at llvm.org
Mon Jun 3 05:50:12 PDT 2024


https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/94213

This patch makes the memory effects reported by existing ops more precise: each effect is now attached to the specific OpOperand that the operation reads or writes, rather than only to the underlying Value. This lets users of the side effect interface determine exactly which operand an effect applies to.

>From 5a7765da1d07e6b75d97f8883c05fb26c1e3d239 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Mon, 3 Jun 2024 12:42:47 +0000
Subject: [PATCH] [mlir][side effect] refactor(*): Include more precise side
 effects

This patch makes the memory effects reported by existing ops more
precise: each effect is now attached to the specific OpOperand that the
operation reads or writes, rather than only to the underlying Value.
---
 .../mlir/Dialect/Affine/IR/AffineOps.h        | 22 ++++++++++---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 13 +++++---
 .../mlir/Interfaces/SideEffectInterfaces.h    | 32 +++++++++++++++++--
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  8 ++---
 .../Bufferization/IR/BufferizationOps.cpp     |  2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |  4 +--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 29 +++++++++++------
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  4 +--
 8 files changed, 86 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index f070d04886190..5c75e102c3d40 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -107,6 +107,9 @@ class AffineDmaStartOp
 
   /// Returns the source MemRefType for this DMA operation.
   Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
+  OpOperand &getSrcMemRefMutable() {
+    return getOperation()->getOpOperand(getSrcMemRefOperandIndex());
+  }
   MemRefType getSrcMemRefType() {
     return cast<MemRefType>(getSrcMemRef().getType());
   }
@@ -117,7 +120,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the source memref.
   AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
   AffineMapAttr getSrcMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getSrcMapAttrStrName()));
   }
 
   /// Returns the source memref affine map indices for this DMA operation.
@@ -139,6 +143,9 @@ class AffineDmaStartOp
 
   /// Returns the destination MemRefType for this DMA operation.
   Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
+  OpOperand &getDstMemRefMutable() {
+    return getOperation()->getOpOperand(getDstMemRefOperandIndex());
+  }
   MemRefType getDstMemRefType() {
     return cast<MemRefType>(getDstMemRef().getType());
   }
@@ -156,7 +163,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the destination memref.
   AffineMap getDstMap() { return getDstMapAttr().getValue(); }
   AffineMapAttr getDstMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getDstMapAttrStrName()));
   }
 
   /// Returns the destination memref indices for this DMA operation.
@@ -173,6 +181,9 @@ class AffineDmaStartOp
 
   /// Returns the Tag MemRef for this DMA operation.
   Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
+  OpOperand &getTagMemRefMutable() {
+    return getOperation()->getOpOperand(getTagMemRefOperandIndex());
+  }
   MemRefType getTagMemRefType() {
     return cast<MemRefType>(getTagMemRef().getType());
   }
@@ -185,7 +196,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref indices for this DMA operation.
@@ -300,6 +312,7 @@ class AffineDmaWaitOp
 
   /// Returns the Tag MemRef associated with the DMA operation being waited on.
   Value getTagMemRef() { return getOperand(0); }
+  OpOperand &getTagMemRefMutable() { return getOperation()->getOpOperand(0); }
   MemRefType getTagMemRefType() {
     return cast<MemRefType>(getTagMemRef().getType());
   }
@@ -307,7 +320,8 @@ class AffineDmaWaitOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref index for this DMA operation.
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 63e6ed059deb1..0606bfd28503a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -706,6 +706,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
   let extraClassDeclaration = [{
     // Returns the source MemRefType for this DMA operation.
     Value getSrcMemRef() { return getOperand(0); }
+    OpOperand &getSrcMemRefMutable() { return getOperation()->getOpOperand(0); }
     // Returns the rank (number of indices) of the source MemRefType.
     unsigned getSrcMemRefRank() {
       return ::llvm::cast<MemRefType>(getSrcMemRef().getType()).getRank();
@@ -718,6 +719,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
 
     // Returns the destination MemRefType for this DMA operations.
     Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
+    OpOperand &getDstMemRefMutable() { return getOperation()->getOpOperand(1 + getSrcMemRefRank()); }
     // Returns the rank (number of indices) of the destination MemRefType.
     unsigned getDstMemRefRank() {
       return ::llvm::cast<MemRefType>(getDstMemRef().getType()).getRank();
@@ -745,6 +747,9 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
     Value getTagMemRef() {
       return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
     }
+    OpOperand &getTagMemRefMutable() {
+      return getOperation()->getOpOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
+    }
 
     // Returns the rank (number of indices) of the tag MemRefType.
     unsigned getTagMemRefRank() {
@@ -801,11 +806,11 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
     void getEffects(
         SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
         effects) {
-      effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
+      effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
                            SideEffects::DefaultResource::get());
-      effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
+      effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
                            SideEffects::DefaultResource::get());
-      effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
+      effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
                            SideEffects::DefaultResource::get());
     }
   }];
@@ -852,7 +857,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
     void getEffects(
         SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
         effects) {
-      effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
+      effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
                            SideEffects::DefaultResource::get());
     }
   }];
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index ec4e36263bbe6..61af0acfb986e 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -149,11 +149,20 @@ class EffectInstance {
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(value), stage(0),
         effectOnFullRegion(false) {}
+  EffectInstance(EffectT *effect, OpOperand *opd,
+                 Resource *resource = DefaultResource::get())
+      : effect(effect), resource(resource), value(opd), stage(0),
+        effectOnFullRegion(false) {}
   EffectInstance(EffectT *effect, Value value, int stage,
                  bool effectOnFullRegion,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(value), stage(stage),
         effectOnFullRegion(effectOnFullRegion) {}
+  EffectInstance(EffectT *effect, OpOperand *opd, int stage,
+                 bool effectOnFullRegion,
+                 Resource *resource = DefaultResource::get())
+      : effect(effect), resource(resource), value(opd), stage(stage),
+        effectOnFullRegion(effectOnFullRegion) {}
   EffectInstance(EffectT *effect, SymbolRefAttr symbol,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(symbol), stage(0),
@@ -176,12 +185,21 @@ class EffectInstance {
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(value),
         parameters(parameters), stage(0), effectOnFullRegion(false) {}
+  EffectInstance(EffectT *effect, OpOperand *opd, Attribute parameters,
+                 Resource *resource = DefaultResource::get())
+      : effect(effect), resource(resource), value(opd), parameters(parameters),
+        stage(0), effectOnFullRegion(false) {}
   EffectInstance(EffectT *effect, Value value, Attribute parameters, int stage,
                  bool effectOnFullRegion,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(value),
         parameters(parameters), stage(stage),
         effectOnFullRegion(effectOnFullRegion) {}
+  EffectInstance(EffectT *effect, OpOperand *opd, Attribute parameters,
+                 int stage, bool effectOnFullRegion,
+                 Resource *resource = DefaultResource::get())
+      : effect(effect), resource(resource), value(opd), parameters(parameters),
+        stage(stage), effectOnFullRegion(effectOnFullRegion) {}
   EffectInstance(EffectT *effect, SymbolRefAttr symbol, Attribute parameters,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(symbol),
@@ -199,7 +217,17 @@ class EffectInstance {
   /// Return the value the effect is applied on, or nullptr if there isn't a
   /// known value being affected.
   Value getValue() const {
-    return value ? llvm::dyn_cast_if_present<Value>(value) : Value();
+    if (!value || llvm::isa_and_present<SymbolRefAttr>(value)) {
+      return Value();
+    }
+    if (Value v = llvm::dyn_cast_if_present<Value>(value)) {
+      return v;
+    }
+    return cast_if_present<OpOperand *>(value)->get();
+  }
+
+  OpOperand *getOpOperand() const {
+    return value ? dyn_cast_if_present<OpOperand *>(value) : nullptr;
   }
 
   /// Return the symbol reference the effect is applied on, or nullptr if there
@@ -229,7 +257,7 @@ class EffectInstance {
   Resource *resource;
 
   /// The Symbol or Value that the effect applies to. This is optionally null.
-  PointerUnion<SymbolRefAttr, Value> value;
+  PointerUnion<SymbolRefAttr, Value, OpOperand *> value;
 
   /// Additional parameters of the effect instance. An attribute is used for
   /// type-safe structured storage and context-based uniquing. Concrete effects
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2e31487bd55a0..3efe93c300f46 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1703,11 +1703,11 @@ LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
 void AffineDmaStartOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
+  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
                        SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
+  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
                        SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
+  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
                        SideEffects::DefaultResource::get());
 }
 
@@ -1793,7 +1793,7 @@ LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
 void AffineDmaWaitOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
+  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
                        SideEffects::DefaultResource::get());
 }
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 3b7b412842bfb..04a8ff30ee946 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -728,7 +728,7 @@ void MaterializeInDestinationOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
   if (isa<BaseMemRefType>(getDest().getType()))
-    effects.emplace_back(MemoryEffects::Write::get(), getDest(),
+    effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                          SideEffects::DefaultResource::get());
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 60b911948d4a0..08259dd6597ca 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -825,7 +825,7 @@ Type GEPOp::getResultPtrElementType() {
 void LoadOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  effects.emplace_back(MemoryEffects::Read::get(), getAddr());
+  effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
   // Volatile operations can have target-specific read-write effects on
   // memory besides the one referred to by the pointer operand.
   // Similarly, atomic operations that are monotonic or stricter cause
@@ -902,7 +902,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
 void StoreOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  effects.emplace_back(MemoryEffects::Write::get(), getAddr());
+  effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
   // Volatile operations can have target-specific read-write effects on
   // memory besides the one referred to by the pointer operand.
   // Similarly, atomic operations that are monotonic or stricter cause
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..1026d121abd17 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1128,7 +1128,8 @@ static void getGenericEffectsImpl(
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
     if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
-      effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+      effects.emplace_back(MemoryEffects::Read::get(),
+                           &linalgOp->getOpOperand(index), /*stage=*/0,
                            /*effectOnFullRegion=*/true,
                            SideEffects::DefaultResource::get());
     }
@@ -1138,13 +1139,16 @@ static void getGenericEffectsImpl(
   for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
+    unsigned operandIdx = index + inputOperandSize;
     if (linalgOp.payloadUsesValueFromOperand(
-            &linalgOp->getOpOperand(index + inputOperandSize))) {
-      effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+            &linalgOp->getOpOperand(operandIdx))) {
+      effects.emplace_back(MemoryEffects::Read::get(),
+                           &linalgOp->getOpOperand(operandIdx), /*stage=*/0,
                            /*effectOnFullRegion=*/true,
                            SideEffects::DefaultResource::get());
     }
-    effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
+    effects.emplace_back(MemoryEffects::Write::get(),
+                         &linalgOp->getOpOperand(operandIdx), /*stage=*/0,
                          /*effectOnFullRegion=*/true,
                          SideEffects::DefaultResource::get());
   }
@@ -2546,20 +2550,27 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
 void SoftmaxOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  for (Value operand : getDpsInputs()) {
+  SmallVector<Value> inputOperands = getDpsInputs();
+  for (auto [index, operand] : llvm::enumerate(inputOperands)) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+    effects.emplace_back(MemoryEffects::Read::get(),
+                         &getOperation()->getOpOperand(index), /*stage=*/0,
                          /*effectOnFullRegion=*/true,
                          SideEffects::DefaultResource::get());
   }
-  for (Value operand : getDpsInits()) {
+
+  unsigned inputOperandSize = inputOperands.size();
+  for (auto [index, operand] : llvm::enumerate(getDpsInits())) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+    unsigned operandIdx = index + inputOperandSize;
+    effects.emplace_back(MemoryEffects::Read::get(),
+                         &getOperation()->getOpOperand(operandIdx), /*stage=*/0,
                          /*effectOnFullRegion=*/true,
                          SideEffects::DefaultResource::get());
-    effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
+    effects.emplace_back(MemoryEffects::Write::get(),
+                         &getOperation()->getOpOperand(operandIdx), /*stage=*/0,
                          /*effectOnFullRegion=*/true,
                          SideEffects::DefaultResource::get());
   }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58951641d33ce..f528c0a7960e7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4123,7 +4123,7 @@ void TransferReadOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
   if (llvm::isa<MemRefType>(getShapedType()))
-    effects.emplace_back(MemoryEffects::Read::get(), getSource(),
+    effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
                          SideEffects::DefaultResource::get());
 }
 
@@ -4497,7 +4497,7 @@ void TransferWriteOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
   if (llvm::isa<MemRefType>(getShapedType()))
-    effects.emplace_back(MemoryEffects::Write::get(), getSource(),
+    effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
                          SideEffects::DefaultResource::get());
 }
 



More information about the Mlir-commits mailing list