[Mlir-commits] [mlir] 0b7362c - [mlir][arith] Add result pretty printing for constant vscale values (#83565)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 5 02:48:20 PDT 2024


Author: Benjamin Maxwell
Date: 2024-04-05T10:48:16+01:00
New Revision: 0b7362c257ff7b656c31266b4f9b8485a7ba4033

URL: https://github.com/llvm/llvm-project/commit/0b7362c257ff7b656c31266b4f9b8485a7ba4033
DIFF: https://github.com/llvm/llvm-project/commit/0b7362c257ff7b656c31266b4f9b8485a7ba4033.diff

LOG: [mlir][arith] Add result pretty printing for constant vscale values (#83565)

In scalable code it is very common to have constant multiples of vscale,
e.g. `4 * vscale`. This updates `arith.muli` to pretty print the result
name in cases like this, so `4 * vscale` would be `%c4_vscale`.

This makes reading IR dumps of scalable code a little nicer.

Added: 
    mlir/test/Dialect/Arith/vscale_constants.mlir

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ead19c69a0831c..4e4c6fd601777b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -343,7 +343,9 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
+def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli",
+  [Commutative, DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]
+> {
   let summary = [{
     Integer multiplication operation.
   }];

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ce0602c853e3c3..1d68a4f7292b53 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -423,6 +423,33 @@ OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
       [](const APInt &a, const APInt &b) { return a * b; });
 }
 
+void arith::MulIOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  if (!isa<IndexType>(getType())) // Only index-typed results are named.
+    return;
+
+  // Match vector.vscale by op name rather than by C++ type, because depending
+  // on the vector dialect from arith would create a circular dependency.
+  auto isVscale = [](Operation *op) {
+    return op && op->getName().getStringRef() == "vector.vscale";
+  };
+
+  IntegerAttr baseValue; // Bound by m_Constant inside isVscaleExpr.
+  auto isVscaleExpr = [&](Value a, Value b) {
+    return matchPattern(a, m_Constant(&baseValue)) &&
+           isVscale(b.getDefiningOp());
+  };
+
+  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
+    return;
+
+  // Either operand order matched: name the result `c<base_value>_vscale`.
+  SmallString<32> specialNameBuffer;
+  llvm::raw_svector_ostream specialName(specialNameBuffer);
+  specialName << 'c' << baseValue.getInt() << "_vscale";
+  setNameFn(getResult(), specialName.str());
+}
+
 void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
   patterns.add<MulIMulIConstant>(context);

diff  --git a/mlir/test/Dialect/Arith/vscale_constants.mlir b/mlir/test/Dialect/Arith/vscale_constants.mlir
new file mode 100644
index 00000000000000..324766f49980f4
--- /dev/null
+++ b/mlir/test/Dialect/Arith/vscale_constants.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Note: This test checks exact value names, so it deliberately avoids regex matches.
+
+func.func @test_vscale_constant_names() {
+  %vscale = vector.vscale
+  %c8 = arith.constant 8 : index
+  // CHECK: %c8_vscale = arith.muli
+  %0 = arith.muli %vscale, %c8 : index // vscale on the LHS.
+  %c10 = arith.constant 10 : index
+  // CHECK: %c10_vscale = arith.muli
+  %1 = arith.muli %c10, %vscale : index // vscale on the RHS.
+  return
+}


        


More information about the Mlir-commits mailing list