[Mlir-commits] [mlir] [MLIR][DLTI] Enable types as keys in DLTI-query utils (PR #105995)

Rolf Morel llvmlistbot at llvm.org
Mon Aug 26 04:16:37 PDT 2024


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/105995

>From b087fdc3509e27032ee9eab7baed10b536bed63f Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 25 Aug 2024 10:54:25 -0700
Subject: [PATCH 1/5] [MLIR][DLTI] Enable types as keys in DLTI-query utils

Enable support for query functions - including transform.dlti.query - to
take types as keys. As the data-layout-specific attributes already
support types as keys, this change enables querying such attributes
in the expected way.
---
 mlir/include/mlir/Dialect/DLTI/DLTI.h         |  2 +-
 .../DLTI/TransformOps/DLTITransformOps.td     |  9 +-
 mlir/lib/Dialect/DLTI/DLTI.cpp                | 26 +++++-
 .../DLTI/TransformOps/DLTITransformOps.cpp    | 11 ++-
 mlir/test/Dialect/DLTI/invalid.mlir           |  8 ++
 mlir/test/Dialect/DLTI/query.mlir             | 88 +++++++++++++++++++
 mlir/test/Dialect/DLTI/valid.mlir             | 15 ++++
 7 files changed, 149 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index a97eb523cb0631..f268fea340a6fb 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -26,7 +26,7 @@ namespace mlir {
 namespace dlti {
 /// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
 /// query interface-implementing attrs, starting from attr obtained from `op`.
-FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
+FailureOr<Attribute> query(Operation *op, ArrayRef<DataLayoutEntryKey> keys,
                            bool emitError = false);
 } // namespace dlti
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
index 1b1bebfaab4e38..f25bb383912d45 100644
--- a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
+++ b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
@@ -26,9 +26,10 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
 
     A lookup is performed for the given `keys` at `target` op - or its closest
     interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
-    returns an attribute for a key. If more than one key is provided, the lookup
-    continues recursively, now on the returned attributes, with the condition
-    that these implement the above interface. For example if the payload IR is
+    returns an attribute for a key. Each key should be either a (quoted) string
+    or a type. If more than one key is provided, the lookup continues
+    recursively, now on the returned attributes, with the condition that these
+    implement the above interface. For example if the payload IR is
 
     ```
     module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
@@ -52,7 +53,7 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       StrArrayAttr:$keys);
+                       ArrayAttr:$keys);
   let results = (outs TransformParamTypeInterface:$associated_attr);
   let assemblyFormat =
       "$keys `at` $target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 7f8e11a1b73341..58f8799b0714d7 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -424,8 +424,8 @@ getClosestQueryable(Operation *op) {
   return std::pair(queryable, op);
 }
 
-FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
-                                 bool emitError) {
+FailureOr<Attribute>
+dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
   auto [queryable, queryOp] = getClosestQueryable(op);
   Operation *reportOp = (queryOp ? queryOp : op);
 
@@ -438,6 +438,17 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
     return failure();
   }
 
+  auto keyToStr = [](DataLayoutEntryKey key) -> std::string {
+    if (auto strKey = llvm::dyn_cast<StringAttr>(key))
+      return "\"" + std::string(strKey.getValue()) + "\"";
+    if (auto typeKey = llvm::dyn_cast<Type>(key)) {
+      std::string buf;
+      llvm::raw_string_ostream(buf) << typeKey;
+      return buf;
+    }
+    llvm_unreachable("DataLayoutEntryKey was not `StringAttr` or `Type`");
+  };
+
   Attribute currentAttr = queryable;
   for (auto &&[idx, key] : llvm::enumerate(keys)) {
     if (auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr)) {
@@ -446,17 +457,24 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
         if (emitError) {
           auto diag = op->emitError() << "target op of failed DLTI query";
           diag.attachNote(reportOp->getLoc())
-              << "key " << key << " has no DLTI-mapping per attr: " << map;
+              << "key " << keyToStr(key)
+              << " has no DLTI-mapping per attr: " << map;
         }
         return failure();
       }
       currentAttr = *maybeAttr;
     } else {
       if (emitError) {
+        std::string commaSeparatedKeys;
+        llvm::interleave(
+            keys.take_front(idx), // All prior keys.
+            [&](auto key) { commaSeparatedKeys += keyToStr(key); },
+            [&]() { commaSeparatedKeys += ","; });
+
         auto diag = op->emitError() << "target op of failed DLTI query";
         diag.attachNote(reportOp->getLoc())
             << "got non-DLTI-queryable attribute upon looking up keys ["
-            << keys.take_front(idx) << "] at op";
+            << commaSeparatedKeys << "] at op";
       }
       return failure();
     }
diff --git a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
index 90aef82bddff00..2f171a8375b46d 100644
--- a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
+++ b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
@@ -33,7 +33,16 @@ void transform::QueryOp::getEffects(
 DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
     transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results, TransformState &state) {
-  auto keys = SmallVector<StringAttr>(getKeys().getAsRange<StringAttr>());
+  auto keys = SmallVector<DataLayoutEntryKey>();
+  for (Attribute key : getKeys()) {
+    if (auto strKey = dyn_cast<StringAttr>(key))
+      keys.push_back(strKey);
+    else if (auto typeKey = dyn_cast<TypeAttr>(key))
+      keys.push_back(typeKey.getValue());
+    else
+      return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: "
+                                 "only StringAttr and TypeAttr are allowed");
+  }
 
   FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true);
 
diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir
index 05f919fa256713..4b04f0195ef823 100644
--- a/mlir/test/Dialect/DLTI/invalid.mlir
+++ b/mlir/test/Dialect/DLTI/invalid.mlir
@@ -33,6 +33,14 @@
 
 // -----
 
+// expected-error at below {{repeated layout entry key: 'i32'}}
+"test.unknown_op"() { test.unknown_attr = #dlti.map<
+  #dlti.dl_entry<i32, 42>,
+  #dlti.dl_entry<i32, 42>
+>} : () -> ()
+
+// -----
+
 // expected-error at below {{repeated layout entry key: 'i32'}}
 "test.unknown_op"() { test.unknown_attr = #dlti.dl_spec<
   #dlti.dl_entry<i32, 42>,
diff --git a/mlir/test/Dialect/DLTI/query.mlir b/mlir/test/Dialect/DLTI/query.mlir
index 10e91afd2ca7e1..e449c2c44bc617 100644
--- a/mlir/test/Dialect/DLTI/query.mlir
+++ b/mlir/test/Dialect/DLTI/query.mlir
@@ -17,6 +17,60 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-remark @below {{associated attr 42 : i32}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, 42 : i32>>} {
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+    %param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param
+    transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-remark @below {{associated attr 32 : i32}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, #dlti.map<#dlti.dl_entry<"width_in_bits", 32 : i32>>>>} {
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+    %param = transform.dlti.query [i32,"width_in_bits"] at %module : (!transform.any_op) -> !transform.any_param
+    transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-remark @below {{width in bits of i32 = 32 : i64}}
+// expected-remark @below {{width in bits of f64 = 64 : i64}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>, #dlti.dl_entry<f64, 64>>>>} {
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+    %i32bits = transform.dlti.query ["width_in_bits",i32] at %module : (!transform.any_op) -> !transform.any_param
+    %f64bits  = transform.dlti.query ["width_in_bits",f64] at %module : (!transform.any_op) -> !transform.any_param
+    transform.debug.emit_param_as_remark %i32bits, "width in bits of i32 =" at %module : !transform.any_param, !transform.any_op
+    transform.debug.emit_param_as_remark %f64bits, "width in bits of f64 =" at %module : !transform.any_param, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // expected-remark @below {{associated attr 42 : i32}}
 module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
   func.func private @f()
@@ -336,6 +390,23 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys [i32]}}
+module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<i32, 32 : i32>>} {
+  // expected-error @below {{target op of failed DLTI query}}
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{'transform.dlti.query' op failed to apply}}
+    %param = transform.dlti.query [i32,"width_in_bits"] at %func : (!transform.any_op) -> !transform.any_param
+    transform.yield
+  }
+}
+
+// -----
+
 module {
   // expected-error @below {{target op of failed DLTI query}}
   // expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}}
@@ -353,6 +424,23 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-note @below {{key i64 has no DLTI-mapping per attr: #dlti.map<#dlti.dl_entry<i32, 32 : i64>>}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>>>>} {
+  // expected-error @below {{target op of failed DLTI query}}
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{'transform.dlti.query' op failed to apply}}
+    %param = transform.dlti.query ["width_in_bits",i64] at %func : (!transform.any_op) -> !transform.any_param
+    transform.yield
+  }
+}
+
+// -----
+
 module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
   func.func private @f()
 }
diff --git a/mlir/test/Dialect/DLTI/valid.mlir b/mlir/test/Dialect/DLTI/valid.mlir
index 4133eac5424ce8..31c925e5cb5bed 100644
--- a/mlir/test/Dialect/DLTI/valid.mlir
+++ b/mlir/test/Dialect/DLTI/valid.mlir
@@ -206,3 +206,18 @@ module attributes {
     "GPU": #dlti.target_device_spec<
       #dlti.dl_entry<"L1_cache_size_in_bytes", "128">>
   >} {}
+
+
+// -----
+
+// CHECK: "test.op_with_dlti_map"() ({
+// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42 : i64>>}
+"test.op_with_dlti_map"() ({
+}) { dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42>> } : () -> ()
+
+// -----
+
+// CHECK: "test.op_with_dlti_map"() ({
+// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<i32, 42 : i64>>}
+"test.op_with_dlti_map"() ({
+}) { dlti.map = #dlti.map<#dlti.dl_entry<i32, 42>> } : () -> ()

>From 2b838c281b3aa98aa88cc39c3225326d9d3e1d28 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 25 Aug 2024 16:53:44 -0700
Subject: [PATCH 2/5] Fully delegate string repr to key's print method

In response to @rengolin's review
---
 mlir/lib/Dialect/DLTI/DLTI.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 58f8799b0714d7..70030f79c75f74 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -439,14 +439,12 @@ dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
   }
 
   auto keyToStr = [](DataLayoutEntryKey key) -> std::string {
-    if (auto strKey = llvm::dyn_cast<StringAttr>(key))
-      return "\"" + std::string(strKey.getValue()) + "\"";
-    if (auto typeKey = llvm::dyn_cast<Type>(key)) {
-      std::string buf;
-      llvm::raw_string_ostream(buf) << typeKey;
-      return buf;
-    }
-    llvm_unreachable("DataLayoutEntryKey was not `StringAttr` or `Type`");
+    std::string buf;
+    llvm::TypeSwitch<DataLayoutEntryKey>(key)
+        .Case<StringAttr, Type>( // The only two kinds of key we know of.
+            [&](auto key) { llvm::raw_string_ostream(buf) << key; })
+        .Default([](auto) { llvm_unreachable("unexpected entry key kind"); });
+    return buf;
   };
 
   Attribute currentAttr = queryable;

>From ddb54f7c8fc86adad9a514a67bcaa5e9aafc2f38 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 25 Aug 2024 17:15:36 -0700
Subject: [PATCH 3/5] Adjust example to make more sense: show how to
 build/query a DLTI set

In response to @rengolin's review
---
 mlir/test/Dialect/DLTI/query.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/DLTI/query.mlir b/mlir/test/Dialect/DLTI/query.mlir
index e449c2c44bc617..748523693c69da 100644
--- a/mlir/test/Dialect/DLTI/query.mlir
+++ b/mlir/test/Dialect/DLTI/query.mlir
@@ -17,8 +17,8 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// expected-remark @below {{associated attr 42 : i32}}
-module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, 42 : i32>>} {
+// expected-remark @below {{i32 present in set : unit}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, unit>>} {
   func.func private @f()
 }
 
@@ -27,7 +27,7 @@ module attributes {transform.with_named_sequence} {
     %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
     %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
     %param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param
-    transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
+    transform.debug.emit_param_as_remark %param, "i32 present in set :" at %module : !transform.any_param, !transform.any_op
     transform.yield
   }
 }

>From fd1702142aec4fd3a7f31a696523a147f342f804 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 26 Aug 2024 04:14:58 -0700
Subject: [PATCH 4/5] Simpler syntax for introducing variable, per @adam-smnk's
 review

---
 mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
index 2f171a8375b46d..02c41b4fe8113f 100644
--- a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
+++ b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
@@ -33,7 +33,7 @@ void transform::QueryOp::getEffects(
 DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
     transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results, TransformState &state) {
-  auto keys = SmallVector<DataLayoutEntryKey>();
+  SmallVector<DataLayoutEntryKey> keys;
   for (Attribute key : getKeys()) {
     if (auto strKey = dyn_cast<StringAttr>(key))
       keys.push_back(strKey);

>From feadc553bbf5ded3f627181f446c7d45ceb27d73 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 26 Aug 2024 04:16:02 -0700
Subject: [PATCH 5/5] Test case for wrong kind of key

---
 mlir/test/Dialect/DLTI/query.mlir | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mlir/test/Dialect/DLTI/query.mlir b/mlir/test/Dialect/DLTI/query.mlir
index 748523693c69da..1a1511d1c718d1 100644
--- a/mlir/test/Dialect/DLTI/query.mlir
+++ b/mlir/test/Dialect/DLTI/query.mlir
@@ -445,6 +445,21 @@ module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32
   func.func private @f()
 }
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{'transform.dlti.query' keys of wrong type: only StringAttr and TypeAttr are allowed}}
+    %param = transform.dlti.query [1] at %funcs : (!transform.any_op) -> !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
+  func.func private @f()
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg: !transform.any_op) {
     %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op



More information about the Mlir-commits mailing list