[Mlir-commits] [mlir] fa6e433 - [mlir][tosa] Fix assertion failure in tosa-layerwise-constant-fold (#85670)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 21 06:02:26 PDT 2024


Author: Spenser Bauman
Date: 2024-03-21T09:02:21-04:00
New Revision: fa6e4338369c787710f1fe682cf6bd62348b9104

URL: https://github.com/llvm/llvm-project/commit/fa6e4338369c787710f1fe682cf6bd62348b9104
DIFF: https://github.com/llvm/llvm-project/commit/fa6e4338369c787710f1fe682cf6bd62348b9104.diff

LOG: [mlir][tosa] Fix assertion failure in tosa-layerwise-constant-fold (#85670)

The existing implementation of tosa-layerwise-constant-fold only works
for constant values backed by DenseElementsAttr. For constants which
hold DenseResourceElementsAttrs, the folder asserts at runtime, as it
assumes that the backing data can always be accessed through
ElementsAttr::getValues.

This change reworks the logic so that the types used to perform
folding are chosen based on whether the ElementsAttr can be converted
to a range of that particular type.

---------

Co-authored-by: Spenser Bauman <sabauma at mathworks.com>
Co-authored-by: Tina Jung <tinamaria.jung at amd.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 050f8ca3f32aed..6575b39fd45a1f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -132,14 +132,17 @@ bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
   return inputOp.hasOneUse();
 }
 
-template <typename BaseType>
-DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+template <typename RangeType>
+DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                 ShapedType outputType,
                                 llvm::ArrayRef<int64_t> permValues) {
+  using ElementType = std::decay_t<decltype(*std::begin(data))>;
+
+  assert(inputType.getElementType() == outputType.getElementType());
+
   if (inputType.getNumElements() == 0)
-    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>{});
 
-  auto attrValues = attr.getValues<BaseType>();
   auto inputShape = inputType.getShape();
 
   // The inverted permutation map and strides of the output are used to compute
@@ -148,10 +151,11 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
   auto outputStrides = computeStrides(outputType.getShape());
   auto invertedPermValues = invertPermutationVector(permValues);
 
-  auto initialValue = *std::begin(attrValues);
-  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+  auto initialValue = *std::begin(data);
+  SmallVector<ElementType> outputValues(inputType.getNumElements(),
+                                        initialValue);
 
-  for (const auto &it : llvm::enumerate(attrValues)) {
+  for (const auto &it : llvm::enumerate(data)) {
     auto srcLinearIndex = it.index();
 
     uint64_t dstLinearIndex = 0;
@@ -170,7 +174,7 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
   }
 
   return DenseElementsAttr::get(outputType,
-                                llvm::ArrayRef<BaseType>(outputValues));
+                                llvm::ArrayRef<ElementType>(outputValues));
 }
 
 // A type specialized transposition of an ElementsAttr.
@@ -180,32 +184,28 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
 DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                             ShapedType outputType,
                             llvm::ArrayRef<int64_t> permValues) {
-  auto baseType = inputType.getElementType();
-
-  // Handle possible integer types
-  if (auto intType = dyn_cast<IntegerType>(baseType)) {
-    switch (intType.getWidth()) {
-    case 1:
-      return transposeType<bool>(attr, inputType, outputType, permValues);
-    case 8:
-      return transposeType<int8_t>(attr, inputType, outputType, permValues);
-    case 16:
-      return transposeType<int16_t>(attr, inputType, outputType, permValues);
-    case 32:
-      return transposeType<int32_t>(attr, inputType, outputType, permValues);
-    case 64:
-      return transposeType<int64_t>(attr, inputType, outputType, permValues);
-    default:
-      return transposeType<APInt>(attr, inputType, outputType, permValues);
-    }
-  }
+  if (auto data = attr.tryGetValues<bool>())
+    return transposeType(*data, inputType, outputType, permValues);
 
-  // Handle possible float types
-  if (baseType.isF32()) {
-    return transposeType<float>(attr, inputType, outputType, permValues);
-  }
+  if (auto data = attr.tryGetValues<int8_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<int16_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<int32_t>())
+    return transposeType(*data, inputType, outputType, permValues);
 
-  return transposeType<APFloat>(attr, inputType, outputType, permValues);
+  if (auto data = attr.tryGetValues<int64_t>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<float>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  if (auto data = attr.tryGetValues<APFloat>())
+    return transposeType(*data, inputType, outputType, permValues);
+
+  return nullptr;
 }
 
 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
@@ -228,14 +228,19 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     DenseIntElementsAttr permAttr;
     if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
       return failure();
-    auto permValues = llvm::to_vector<6>(llvm::map_range(
+    auto permValues = llvm::map_to_vector(
         // TOSA allows both 32- and 64-bit integer tensors here.
         permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getSExtValue(); }));
+        [](const APInt &val) { return val.getSExtValue(); });
 
     auto inputType = cast<ShapedType>(op.getInput1().getType());
 
     auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    if (!resultAttr) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported attribute or element type");
+    }
+
     rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
     return success();
   }

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 27ca3ae3c21be6..de752f31fcbaa1 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -112,6 +112,23 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
   return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
 }
 
+// CHECK-LABEL: @transpose_nofold_dense_resource
+func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
+  %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
+  %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+
+  // CHECK: tosa.transpose
+  %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
+  return %2 : tensor<2x2xf32>
+}
+{-#
+  dialect_resources: {
+    builtin: {
+      resource: "0x08000000010000000000000002000000000000000300000000000000"
+    }
+  }
+#-}
+
 // -----
 
 // CHECK-LABEL: @fold_add_zero_rhs_f32


        


More information about the Mlir-commits mailing list