[Mlir-commits] [mlir] [mlir][tosa] Fix assertion failure in tosa-layerwise-constant-fold (PR #85670)

Spenser Bauman llvmlistbot at llvm.org
Mon Mar 18 10:37:34 PDT 2024


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/85670

The existing implementation of tosa-layerwise-constant-fold only works for constant values backed by DenseElementsAttr. For constants which hold DenseResourceAttrs, the folder will end up asserting at runtime, as it assumes that the backing data can always be accessed through ElementsAttr::getValues.

This change reworks the logic so that the types used to perform folding are chosen based on whether the ElementsAttr can be converted to a range of that particular type.

>From 7981c1d4d0026743ab6add99271da4ad33d9e617 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Mon, 18 Mar 2024 13:04:16 -0400
Subject: [PATCH] [mlir][tosa] Fix assertion failure in
 tosa-layerwise-constant-fold

The existing implementation of tosa-layerwise-constant-fold only works
for constant values backed by DenseElementsAttr. For constants which
hold DenseResourceAttrs, the folder will end up asserting at runtime, as
it assumes that the backing data can always be accessed through
ElementsAttr::getValues.

This change reworks the logic so that the types used to perform
folding are chosen based on whether the ElementsAttr can be converted
to a range of that particular type.
---
 .../Dialect/Tosa/Transforms/TosaFolders.cpp   | 73 ++++++++++---------
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  | 19 +++++
 2 files changed, 58 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 050f8ca3f32aed..6575b39fd45a1f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -132,14 +132,17 @@ bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
   return inputOp.hasOneUse();
 }
 
-template <typename BaseType>
-DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+template <typename RangeType>
+DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                 ShapedType outputType,
                                 llvm::ArrayRef<int64_t> permValues) {
+  using ElementType = std::decay_t<decltype(*std::begin(data))>;
+
+  assert(inputType.getElementType() == outputType.getElementType());
+
   if (inputType.getNumElements() == 0)
-    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});
 
-  auto attrValues = attr.getValues<BaseType>();
   auto inputShape = inputType.getShape();
 
   // The inverted permutation map and strides of the output are used to compute
@@ -148,10 +151,11 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
   auto outputStrides = computeStrides(outputType.getShape());
   auto invertedPermValues = invertPermutationVector(permValues);
 
-  auto initialValue = *std::begin(attrValues);
-  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+  auto initialValue = *std::begin(data);
+  SmallVector<ElementType> outputValues(inputType.getNumElements(),
+                                        initialValue);
 
-  for (const auto &it : llvm::enumerate(attrValues)) {
+  for (const auto &it : llvm::enumerate(data)) {
     auto srcLinearIndex = it.index();
 
     uint64_t dstLinearIndex = 0;
@@ -170,7 +174,7 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
   }
 
   return DenseElementsAttr::get(outputType,
-                                llvm::ArrayRef<BaseType>(outputValues));
+                                llvm::ArrayRef<ElementType>(outputValues));
 }
 
 // A type specialized transposition of an ElementsAttr.
@@ -180,32 +184,28 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
 DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                             ShapedType outputType,
                             llvm::ArrayRef<int64_t> permValues) {
-  auto baseType = inputType.getElementType();
-
-  // Handle possible integer types
-  if (auto intType = dyn_cast<IntegerType>(baseType)) {
-    switch (intType.getWidth()) {
-    case 1:
-      return transposeType<bool>(attr, inputType, outputType, permValues);
-    case 8:
-      return transposeType<int8_t>(attr, inputType, outputType, permValues);
-    case 16:
-      return transposeType<int16_t>(attr, inputType, outputType, permValues);
-    case 32:
-      return transposeType<int32_t>(attr, inputType, outputType, permValues);
-    case 64:
-      return transposeType<int64_t>(attr, inputType, outputType, permValues);
-    default:
-      return transposeType<APInt>(attr, inputType, outputType, permValues);
-    }
-  }
+  if (auto data = attr.tryGetValues<bool>())
+    return transposeType(*data, inputType, outputType, permValues);
 
-  // Handle possible float types
-  if (baseType.isF32()) {
-    return transposeType<float>(attr, inputType, outputType, permValues);
-  }
+  if (auto data = attr.tryGetValues<int8_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<int16_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<int32_t>())
+    return transposeType(*data, inputType, outputType, permValues);
 
-  return transposeType<APFloat>(attr, inputType, outputType, permValues);
+  if (auto data = attr.tryGetValues<int64_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<float>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<APFloat>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  return nullptr;
 }
 
 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
@@ -228,14 +228,19 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     DenseIntElementsAttr permAttr;
     if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
       return failure();
-    auto permValues = llvm::to_vector<6>(llvm::map_range(
+    auto permValues = llvm::map_to_vector(
         // TOSA allows both 32- and 64-bit integer tensors here.
         permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getSExtValue(); }));
+        [](const APInt &val) { return val.getSExtValue(); });
 
     auto inputType = cast<ShapedType>(op.getInput1().getType());
 
     auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    if (!resultAttr) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported attribute or element type");
+    }
+
     rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
     return success();
   }
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 27ca3ae3c21be6..c07ef3b60c7da3 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -114,6 +114,25 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
 
 // -----
 
+// CHECK-LABEL: @transpose_dense_resource
+func.func @transpose_dense_resource() -> tensor<2x2xf32> {
+    %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
+    %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+
+    // CHECK: tosa.transpose
+    %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
+    return %2 : tensor<2x2xf32>
+}
+{-#
+  dialect_resources: {
+    builtin: {
+      resource: "0x08000000010000000000000002000000000000000300000000000000"
+    }
+  }
+#-}
+
+// -----
+
 // CHECK-LABEL: @fold_add_zero_rhs_f32
 func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>



More information about the Mlir-commits mailing list