[Mlir-commits] [mlir] [mlir][linalg] Add runtime verification for linalg.generic (PR #89342)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 18 20:38:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-memref

Author: Ryan Holt (ryan-holt-1)

<details>
<summary>Changes</summary>

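This commit implements runtime verification for `linalg.generic` using the existing `RuntimeVerifiableOpInterface`. The verification checks that the runtime sizes of the operands are compatible with the runtime sizes inferred by composing the loop ranges with the generic's indexing maps. The sizes must match exactly when an indexing-map result is a plain loop dimension. Otherwise, the operand must merely be large enough to contain every accessed index.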

---

Patch is 28.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89342.diff


9 Files Affected:

- (added) mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h (+22) 
- (modified) mlir/include/mlir/InitAllDialects.h (+2) 
- (modified) mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td (+6) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp (+125) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+21-33) 
- (modified) mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp (+22) 
- (added) mlir/test/Dialect/Linalg/runtime-verification.mlir (+43) 
- (added) mlir/test/Integration/Dialect/Linalg/CPU/generic-runtime-verification.mlir (+279) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
new file mode 100644
index 00000000000000..ae11a3da0fe838
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h
@@ -0,0 +1,22 @@
+//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
+#define MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+/// Registers, in `registry`, the `RuntimeVerifiableOpInterface` external model
+/// for `linalg.generic`. Registration is done through a `LinalgDialect`
+/// extension, so the model is attached when that dialect is loaded into a
+/// context.
+void registerRuntimeVerifiableOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H
+
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c4d788cf8ed316..d9db21073e15c7 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -45,6 +45,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
@@ -161,6 +162,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
   linalg::registerAllDialectInterfaceImplementations(registry);
+  linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
index d5f11d00cc3d2a..6fd0df59d9d2e0 100644
--- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -35,6 +35,12 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
                     "::mlir::Location":$loc)
     >,
   ];
+
+  let extraClassDeclaration = [{
+    /// Generate the error message that will be printed to the user when 
+    /// verification fails.
+    static std::string generateErrorMessage(Operation *op, const std::string &msg);
+  }];
 }
 
 #endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 513c54de5d7bfc..1cf94f0bfb39f3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   NamedOpConversions.cpp
   Padding.cpp
   Promotion.cpp
+  RuntimeOpVerification.cpp
   Specialize.cpp
   Split.cpp
   SplitReduction.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
new file mode 100644
index 00000000000000..3ca0bc78da16be
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -0,0 +1,125 @@
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexAttrs.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+namespace linalg {
+namespace {
+/// Verify that the runtime sizes of the operands to linalg.generic are
+/// compatible with the runtime sizes inferred by composing the loop ranges with
+/// the generic's indexing maps.
+///
+/// For every dimension of every input/output operand two `cf.assert`s are
+/// emitted:
+///   * the smallest index the loop nest produces for that dimension must be
+///     non-negative, and
+///   * the extent implied by the largest produced index (max index + 1) must
+///     equal the operand's runtime size when the indexing-map result is a plain
+///     loop dimension, and must not exceed it otherwise.
+/// When any loop has a trip count of zero the body never executes and no
+/// operand element is accessed, so both conditions are treated as satisfied.
+struct GenericOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<GenericOpInterface,
+                                                         linalg::GenericOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto linalgOp = llvm::cast<LinalgOp>(op);
+
+    // NOTE(review): the "last index = size - 1" computation below assumes the
+    // ranges returned by `createLoopRanges` have a zero offset and unit stride.
+    SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
+    auto [starts, sizes, _] = getOffsetsSizesAndStrides(loopRanges);
+
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+    // For each loop compute its last iteration index (`size - 1`) and whether
+    // the loop is empty. An empty loop yields a last index of -1, which would
+    // otherwise trip the checks below even though nothing is accessed.
+    SmallVector<OpFoldResult> lastIndices;
+    lastIndices.reserve(sizes.size());
+    Value anyLoopEmpty;
+    for (OpFoldResult size : sizes) {
+      Value sizeValue = getValueOrCreateConstantIndexOp(builder, loc, size);
+      lastIndices.push_back(
+          builder.createOrFold<index::SubOp>(loc, sizeValue, one));
+      Value isEmpty = builder.createOrFold<index::CmpOp>(
+          loc, index::IndexCmpPredicate::EQ, sizeValue, zero);
+      anyLoopEmpty = anyLoopEmpty ? builder.createOrFold<arith::OrIOp>(
+                                        loc, anyLoopEmpty, isEmpty)
+                                  : isEmpty;
+    }
+
+    // Relax `cond` to `cond || anyLoopEmpty`. The possibly-constant
+    // `anyLoopEmpty` is placed on the right-hand side so that folding can
+    // reduce `cond || false` back to `cond` when all loop sizes are static and
+    // non-zero. With no loops at all (`anyLoopEmpty` is null) `cond` is kept.
+    auto unlessLoopNestEmpty = [&](Value cond) -> Value {
+      if (!anyLoopEmpty)
+        return cond;
+      return builder.createOrFold<arith::OrIOp>(loc, cond, anyLoopEmpty);
+    };
+
+    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+      auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
+          builder, loc, indexingMap, starts);
+      auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
+          builder, loc, indexingMap, lastIndices);
+
+      for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
+        auto startIndex =
+            getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
+        auto endIndex =
+            getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
+
+        // Generate:
+        //   minIndex = min(startIndex, endIndex)
+        //   assert(minIndex >= 0 || anyLoopEmpty)
+        // To ensure we do not generate a negative index. We take the minimum of
+        // the start and end indices in order to handle reverse loops such as
+        // `affine_map<(i) -> (3 - i)>`
+        auto min =
+            builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
+        auto cmpOp = builder.createOrFold<index::CmpOp>(
+            loc, index::IndexCmpPredicate::SGE, min, zero);
+        auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
+            linalgOp, "unexpected negative result on dimension #" +
+                          std::to_string(dim) + " of input/output operand #" +
+                          std::to_string(opOperand.getOperandNumber()));
+        builder.createOrFold<cf::AssertOp>(loc, unlessLoopNestEmpty(cmpOp),
+                                           msg);
+
+        // Generate:
+        //   inferredDimSize = max(startIndex, endIndex) + 1
+        //   actualDimSize = dim(operand)
+        //   assert(inferredDimSize <= actualDimSize || anyLoopEmpty)
+        // To ensure that we do not index past the bounds of the operands.
+        auto max =
+            builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
+
+        auto inferredDimSize =
+            builder.createOrFold<index::AddOp>(loc, max, one);
+
+        auto actualDimSize =
+            createOrFoldDimOp(builder, loc, opOperand.get(), dim);
+
+        // Similar to the verifier, when the affine expression in the indexing
+        // map is complicated, we just check that the inferred dimension sizes
+        // are in the boundary of the operands' size. Being more precise than
+        // that is difficult.
+        auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
+                             ? index::IndexCmpPredicate::EQ
+                             : index::IndexCmpPredicate::SLE;
+
+        cmpOp = builder.createOrFold<index::CmpOp>(
+            loc, predicate, inferredDimSize, actualDimSize);
+        msg = RuntimeVerifiableOpInterface::generateErrorMessage(
+            linalgOp, "dimension #" + std::to_string(dim) +
+                          " of input/output operand #" +
+                          std::to_string(opOperand.getOperandNumber()) +
+                          " is incompatible with inferred dimension size");
+        builder.createOrFold<cf::AssertOp>(loc, unlessLoopNestEmpty(cmpOp),
+                                           msg);
+      }
+    }
+  }
+};
+} // namespace
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  // Runs once LinalgDialect is loaded into a context.
+  registry.addExtension(+[](MLIRContext *context, LinalgDialect * /*dialect*/) {
+    GenericOp::attachInterface<GenericOpInterface>(*context);
+
+    // The verification IR built by GenericOpInterface may contain ops from
+    // these dialects, so make sure they are available in the context.
+    context->loadDialect<affine::AffineDialect, arith::ArithDialect,
+                         cf::ControlFlowDialect, index::IndexDialect,
+                         tensor::TensorDialect>();
+  });
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 05b813a3b1e908..450bfa0cec0c7f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -20,25 +20,6 @@
 
 using namespace mlir;
 
-/// Generate an error message string for the given op and the specified error.
-static std::string generateErrorMessage(Operation *op, const std::string &msg) {
-  std::string buffer;
-  llvm::raw_string_ostream stream(buffer);
-  OpPrintingFlags flags;
-  // We may generate a lot of error messages and so we need to ensure the
-  // printing is fast.
-  flags.elideLargeElementsAttrs();
-  flags.printGenericOpForm();
-  flags.skipRegions();
-  flags.useLocalScope();
-  stream << "ERROR: Runtime op verification failed\n";
-  op->print(stream, flags);
-  stream << "\n^ " << msg;
-  stream << "\nLocation: ";
-  op->getLoc().print(stream);
-  return stream.str();
-}
-
 namespace mlir {
 namespace memref {
 namespace {
@@ -62,8 +43,10 @@ struct CastOpInterface
           builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
       Value isSameRank = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcRank, resultRank);
-      builder.create<cf::AssertOp>(loc, isSameRank,
-                                   generateErrorMessage(op, "rank mismatch"));
+      builder.create<cf::AssertOp>(
+          loc, isSameRank,
+          RuntimeVerifiableOpInterface::generateErrorMessage(op,
+                                                             "rank mismatch"));
     }
 
     // Get source offset and strides. We do not have an op to get offsets and
@@ -101,8 +84,8 @@ struct CastOpInterface
           loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
       builder.create<cf::AssertOp>(
           loc, isSameSz,
-          generateErrorMessage(op, "size mismatch of dim " +
-                                       std::to_string(it.index())));
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "size mismatch of dim " + std::to_string(it.index())));
     }
 
     // Get result offset and strides.
@@ -119,8 +102,10 @@ struct CastOpInterface
           builder.create<arith::ConstantIndexOp>(loc, resultOffset);
       Value isSameOffset = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
-      builder.create<cf::AssertOp>(loc, isSameOffset,
-                                   generateErrorMessage(op, "offset mismatch"));
+      builder.create<cf::AssertOp>(
+          loc, isSameOffset,
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "offset mismatch"));
     }
 
     // Check strides.
@@ -137,8 +122,8 @@ struct CastOpInterface
           loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
       builder.create<cf::AssertOp>(
           loc, isSameStride,
-          generateErrorMessage(op, "stride mismatch of dim " +
-                                       std::to_string(it.index())));
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "stride mismatch of dim " + std::to_string(it.index())));
     }
   }
 };
@@ -178,7 +163,9 @@ struct LoadStoreOpInterface
                 : andOp;
     }
     builder.create<cf::AssertOp>(
-        loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
+        loc, assertCond,
+        RuntimeVerifiableOpInterface::generateErrorMessage(
+            op, "out-of-bounds access"));
   }
 };
 
@@ -248,7 +235,7 @@ struct ReinterpretCastOpInterface
 
     builder.create<cf::AssertOp>(
         loc, assertCond,
-        generateErrorMessage(
+        RuntimeVerifiableOpInterface::generateErrorMessage(
             op,
             "result of reinterpret_cast is out-of-bounds of the base memref"));
   }
@@ -293,8 +280,8 @@ struct SubViewOpInterface
 
     builder.create<cf::AssertOp>(
         loc, assertCond,
-        generateErrorMessage(op,
-                             "subview is out-of-bounds of the base memref"));
+        RuntimeVerifiableOpInterface::generateErrorMessage(
+            op, "subview is out-of-bounds of the base memref"));
   }
 };
 
@@ -334,8 +321,9 @@ struct ExpandShapeOpInterface
           builder.create<arith::ConstantIndexOp>(loc, 0));
       builder.create<cf::AssertOp>(
           loc, isModZero,
-          generateErrorMessage(op, "static result dims in reassoc group do not "
-                                   "divide src dim evenly"));
+          RuntimeVerifiableOpInterface::generateErrorMessage(
+              op, "static result dims in reassoc group do not "
+                  "divide src dim evenly"));
     }
   }
 };
diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
index 9205d8d8c34a29..e823b5df179c50 100644
--- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
+++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
@@ -11,6 +11,28 @@
 namespace mlir {
 class Location;
 class OpBuilder;
+
+/// Builds the text reported when a runtime verification check of `op` fails:
+/// a fixed header line, `op` printed in generic form (regions skipped, large
+/// elements attributes elided, local-scope naming), a caret line carrying
+/// `msg`, and finally the op's location.
+std::string
+RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
+                                                   const std::string &msg) {
+  // Many of these messages may be produced, so keep printing inexpensive.
+  OpPrintingFlags printingFlags;
+  printingFlags.elideLargeElementsAttrs()
+      .printGenericOpForm()
+      .skipRegions()
+      .useLocalScope();
+
+  std::string message;
+  llvm::raw_string_ostream os(message);
+  os << "ERROR: Runtime op verification failed\n";
+  op->print(os, printingFlags);
+  os << "\n^ " << msg << "\nLocation: ";
+  op->getLoc().print(os);
+  return os.str();
+}
+
 } // namespace mlir
 
 /// Include the definitions of the interface.
diff --git a/mlir/test/Dialect/Linalg/runtime-verification.mlir b/mlir/test/Dialect/Linalg/runtime-verification.mlir
new file mode 100644
index 00000000000000..a4f29d8457e589
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/runtime-verification.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s
+
+// Most of the tests for linalg runtime-verification are implemented as integration tests.
+
+#identity = affine_map<(d0) -> (d0)>
+
+// Static operand shapes that agree with the identity indexing maps: the
+// generated size conditions fold to constants, so the asserts only receive a
+// constant `true`.
+// CHECK-LABEL: @static_dims
+func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
+    // CHECK: %[[TRUE:.*]] = index.bool.constant true
+    // CHECK: cf.assert %[[TRUE]]
+    %result = tensor.empty() : tensor<5xf32> 
+    %0 = linalg.generic {
+      indexing_maps = [#identity, #identity, #identity],
+      iterator_types = ["parallel"]
+    } ins(%arg0, %arg1 : tensor<5xf32>, tensor<5xf32>)
+      outs(%result : tensor<5xf32>) {
+      ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
+        %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
+        linalg.yield %tmp1 : f32
+    } -> tensor<5xf32>
+    return %0 : tensor<5xf32>
+}
+
+// NOTE(review): the RUN line does not pass -split-input-file, so the separator
+// below is only a visual divider; both functions are processed as one module.
+// -----
+
+#map = affine_map<() -> ()>
+
+// Every operand is rank-0, so there is no dimension to verify and the
+// interface emits no assert at all.
+// CHECK-LABEL: @scalars
+func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
+    // No runtime checks are required if the operands are all scalars
+    // CHECK-NOT: cf.assert
+    %result = tensor.empty() : tensor<f32> 
+    %0 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = []
+    } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
+      outs(%result : tensor<f32>) {
+      ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) :
+        %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32
+        linalg.yield %tmp1 : f32
+    } -> tensor<f32>
+    return %0 : tensor<f32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/generic-runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/generic-runtime-verification.mlir
new file mode 100644
index 00000000000000..d159d6ccd6c543
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/generic-runtime-verification.mlir
@@ -0,0 +1,279 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -convert-linalg-to-loops \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-scf-to-cf \
+// RUN: -test-cf-assert \
+// RUN: -convert-index-to-llvm \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils \
+// RUN:     -shared-libs=%mlir_c_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+  %c5x = arith.constant dense<0.0> : tensor<5xf32>
+  %c4x = arith.constant dense<0.0> : tensor<4xf32>
+  %d5x = tensor.cast %c5x : tensor<5xf32> to tensor<?xf32>
+  %d4x = tensor.cast %c4x : tensor<4xf32> to tensor<?xf32>
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+  // CHECK: ERROR: Runtime op verification failed
+  // CHECK: linalg.generic
+  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+  func.call @simple_add(%d5x, %d4x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+  // CHECK: ERROR: Runtime op verification failed
+  // CHECK: linalg.generic
+  // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size
+  func.call @simple_add(%d4x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
+
+  %c1x1 = arith.constant dense<0.0> : tensor<1x1xf32>
+  %c1x4 = arith.constant dense<0.0> : tensor<1x4xf32>
+  %c4x4 = arith.constant dense<0.0> : tensor<4x4xf32>
+  %c4x5 = arith.constant dense<0.0> : tensor<4x5xf32>
+  %c5x4 = arith.constant dense<0.0> : tensor<5x4xf32>
+  %d1x1 = tensor.cast %c1x1 : tensor<1x1xf32> to tensor<?x?xf32>
+  %d1x4 = tensor.cast %c1x4 : tensor<1x4xf32> to tensor<?x?xf32>
+  %d4x4 = tensor.cast %c4x4 : tensor<4x4xf32> to tensor<?x?xf32>
+  %d4x5 = tensor.cast %c4x5 : tensor<4x5xf32> to tensor<?x?xf32>
+  %d5x4 = tensor.cast %c5x4 : tensor<5x4xf32> to tensor<?x?xf32>
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
+  // CHECK: ERROR: Runtime op verification failed
+  // CHECK: linalg.generic
+  // CHECK: ^ dimension #1 of input/output operand...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/89342


More information about the Mlir-commits mailing list