[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Oct 2 01:43:28 PDT 2024


================
@@ -870,5 +870,80 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
   return reassociation;
 }
 
+// Returns true if `val` is a constant zero: either a floating-point zero of
+// any sign (m_AnyZeroFloat) or an integer zero (m_Zero). Non-zero constants
+// and non-constant values return false.
+bool isValConstZero(Value val) {
+  return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero());
+}
+
+// Returns true if the attribute represents "all zeros": a scalar float or
+// integer zero, or a splat DenseElementsAttr with an int/float element type
+// whose splat value is itself zero. Any other attribute returns false.
+static bool isZeroAttr(Attribute attribute) {
+  return TypeSwitch<Attribute, bool>(attribute)
+      // APFloat::isZero() is exact and accepts both +0.0 and -0.0; comparing
+      // getValueAsDouble() against 0.0 goes through a lossy conversion for
+      // formats wider than double.
+      .Case<FloatAttr>([](auto attr) { return attr.getValue().isZero(); })
+      // APInt::isZero() instead of IntegerAttr::getInt(): getInt() asserts
+      // on signed/unsigned (non-signless) integer types and on values that do
+      // not fit in 64 bits.
+      .Case<IntegerAttr>([](auto attr) { return attr.getValue().isZero(); })
+      .Case<DenseElementsAttr>([](auto attr) {
+        // Only splats of int/float elements are recognised.
+        if (!attr.getElementType().isIntOrFloat() || !attr.isSplat())
+          return false;
+        return isZeroAttr(attr.template getSplatValue<Attribute>());
+      })
+      .Default([](Attribute) { return false; });
+}
+
+// Returns true if `val` is known to hold only zeros. Constant zeros are
+// matched directly; otherwise the check recurses into isZeroOp on the value's
+// defining op. Block arguments have no defining op, so a block argument of a
+// linalg.generic body is mapped to the generic's operand at the same position
+// and that operand's defining op is inspected instead. Block arguments of any
+// other op return false.
+// NOTE(review): isZeroOp is defined after this function; a prior declaration
+// (presumably in a header) is assumed.
+bool isZeroTensor(Value val) {
+  if (!val)
+    return false;
+  if (isValConstZero(val))
+    return true;
+
+  Operation *defOp = nullptr;
+
+  if (auto arg = dyn_cast<BlockArgument>(val)) {
+    // Block::getParentOp() yields null when the block is detached or its
+    // region has no parent op; isa_and_nonnull handles that instead of
+    // dereferencing a null region / asserting inside isa<>.
+    Operation *linalgOp = arg.getOwner()->getParentOp();
+    if (!isa_and_nonnull<linalg::GenericOp>(linalgOp))
+      return false;
+    // Relies on linalg.generic's block arguments corresponding, in order, to
+    // its operands, so the argument number doubles as the operand index.
+    Value linalgArg = linalgOp->getOperand(arg.getArgNumber());
+    defOp = linalgArg.getDefiningOp();
+  } else {
+    defOp = val.getDefiningOp();
+  }
+  return isZeroOp(defOp);
+}
+
+// Recurses into isZeroTensor for operands and isZeroAttr for attributes.
+bool isZeroOp(Operation *defOp) {
+  if (!defOp)
+    return false;
+
+  return TypeSwitch<Operation *, bool>(defOp)
+      .Case<arith::ConstantOp>([&](auto op) {
+        // Dense attributes don't match APFloat.isZero().
+        Attribute attr = op.getValue();
+        return isZeroAttr(attr);
+      })
+      .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
+        if (op.getInputs().size() != 1)
+          return false;
+        return isZeroTensor(op.getInputs()[0]);
+      })
+      .Case<memref::CopyOp, memref::SubViewOp, tensor::CastOp,
+            tensor::ExtractSliceOp>(
+          [&](auto op) { return isZeroTensor(op.getSource()); })
+      .Case<memref::GetGlobalOp>([&](auto op) {
----------------
banach-space wrote:

Is there a test for this case?

https://github.com/llvm/llvm-project/pull/110514


More information about the Mlir-commits mailing list