[Mlir-commits] [mlir] [mlir][tosa] Fix assertion failure in tosa-layerwise-constant-fold (PR #85670)

Spenser Bauman llvmlistbot at llvm.org
Wed Mar 20 12:40:03 PDT 2024


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/85670

>From 894852be2cf6cb8933c52cdbd8f8f93370beee47 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Mon, 18 Mar 2024 13:04:16 -0400
Subject: [PATCH 1/2] [mlir][tosa] Fix assertion failure in
 tosa-layerwise-constant-fold

The existing implementation of tosa-layerwise-constant-fold only works
for constant values backed by DenseElementsAttr. For constants which
hold DenseResourceAttrs, the folder will end up asserting at runtime, as
it assumes that the backing data can always be accessed through
ElementsAttr::getValues.

This change reworks the logic so that the types used to perform
folding are based on whether the ElementsAttr can be converted to
a range of that particular type.
---
 .../Dialect/Tosa/Transforms/TosaFolders.cpp   | 73 ++++++++++---------
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  | 19 +++++
 2 files changed, 58 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 050f8ca3f32aed..6575b39fd45a1f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -132,14 +132,17 @@ bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
   return inputOp.hasOneUse();
 }
 
-template <typename BaseType>
-DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+template <typename RangeType>
+DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                 ShapedType outputType,
                                 llvm::ArrayRef<int64_t> permValues) {
+  using ElementType = std::decay_t<decltype(*std::begin(data))>;
+
+  assert(inputType.getElementType() == outputType.getElementType());
+
   if (inputType.getNumElements() == 0)
-    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});
 
-  auto attrValues = attr.getValues<BaseType>();
   auto inputShape = inputType.getShape();
 
   // The inverted permutation map and strides of the output are used to compute
@@ -148,10 +151,11 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
   auto outputStrides = computeStrides(outputType.getShape());
   auto invertedPermValues = invertPermutationVector(permValues);
 
-  auto initialValue = *std::begin(attrValues);
-  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+  auto initialValue = *std::begin(data);
+  SmallVector<ElementType> outputValues(inputType.getNumElements(),
+                                        initialValue);
 
-  for (const auto &it : llvm::enumerate(attrValues)) {
+  for (const auto &it : llvm::enumerate(data)) {
     auto srcLinearIndex = it.index();
 
     uint64_t dstLinearIndex = 0;
@@ -170,7 +174,7 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
   }
 
   return DenseElementsAttr::get(outputType,
-                                llvm::ArrayRef<BaseType>(outputValues));
+                                llvm::ArrayRef<ElementType>(outputValues));
 }
 
 // A type specialized transposition of an ElementsAttr.
@@ -180,32 +184,28 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
 DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                             ShapedType outputType,
                             llvm::ArrayRef<int64_t> permValues) {
-  auto baseType = inputType.getElementType();
-
-  // Handle possible integer types
-  if (auto intType = dyn_cast<IntegerType>(baseType)) {
-    switch (intType.getWidth()) {
-    case 1:
-      return transposeType<bool>(attr, inputType, outputType, permValues);
-    case 8:
-      return transposeType<int8_t>(attr, inputType, outputType, permValues);
-    case 16:
-      return transposeType<int16_t>(attr, inputType, outputType, permValues);
-    case 32:
-      return transposeType<int32_t>(attr, inputType, outputType, permValues);
-    case 64:
-      return transposeType<int64_t>(attr, inputType, outputType, permValues);
-    default:
-      return transposeType<APInt>(attr, inputType, outputType, permValues);
-    }
-  }
+  if (auto data = attr.tryGetValues<bool>())
+    return transposeType(*data, inputType, outputType, permValues);
 
-  // Handle possible float types
-  if (baseType.isF32()) {
-    return transposeType<float>(attr, inputType, outputType, permValues);
-  }
+  if (auto data = attr.tryGetValues<int8_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<int16_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<int32_t>())
+    return transposeType(*data, inputType, outputType, permValues);
 
-  return transposeType<APFloat>(attr, inputType, outputType, permValues);
+  if (auto data = attr.tryGetValues<int64_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<float>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<APFloat>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  return nullptr;
 }
 
 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
@@ -228,14 +228,19 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     DenseIntElementsAttr permAttr;
     if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
       return failure();
-    auto permValues = llvm::to_vector<6>(llvm::map_range(
+    auto permValues = llvm::map_to_vector(
         // TOSA allows both 32- and 64-bit integer tensors here.
         permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getSExtValue(); }));
+        [](const APInt &val) { return val.getSExtValue(); });
 
     auto inputType = cast<ShapedType>(op.getInput1().getType());
 
     auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    if (!resultAttr) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported attribute or element type");
+    }
+
     rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
     return success();
   }
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 27ca3ae3c21be6..c07ef3b60c7da3 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -114,6 +114,25 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
 
 // -----
 
+// CHECK-LABEL: @transpose_dense_resource
+func.func @transpose_dense_resource() -> tensor<2x2xf32> {
+    %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
+    %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+
+    // CHECK: tosa.transpose
+    %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
+    return %2 : tensor<2x2xf32>
+}
+{-#
+  dialect_resources: {
+    builtin: {
+      resource: "0x08000000010000000000000002000000000000000300000000000000"
+    }
+  }
+#-}
+
+// -----
+
 // CHECK-LABEL: @fold_add_zero_rhs_f32
 func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>

>From ba0be5e57cbdde678b057c00237e54cbf3ea9dfa Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail.com>
Date: Tue, 19 Mar 2024 09:51:35 -0400
Subject: [PATCH 2/2] Update mlir/test/Dialect/Tosa/constant-op-fold.mlir

Apply TinaAMD's suggested renaming

Co-authored-by: Tina Jung <tinamaria.jung at amd.com>
---
 mlir/test/Dialect/Tosa/constant-op-fold.mlir | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index c07ef3b60c7da3..de752f31fcbaa1 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -112,16 +112,14 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
   return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
 }
 
-// -----
-
-// CHECK-LABEL: @transpose_dense_resource
-func.func @transpose_dense_resource() -> tensor<2x2xf32> {
-    %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
-    %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+// CHECK-LABEL: @transpose_nofold_dense_resource
+func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
+  %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
+  %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
 
-    // CHECK: tosa.transpose
-    %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
-    return %2 : tensor<2x2xf32>
+  // CHECK: tosa.transpose
+  %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
+  return %2 : tensor<2x2xf32>
 }
 {-#
   dialect_resources: {



More information about the Mlir-commits mailing list