LCOV - code coverage report
Current view: top level - backend/catalog - emulator_ml_tvf_extensions_test.cc (source / functions) Coverage Total Hit
Test: _coverage_report.dat Lines: 98.6 % 70 69
Test Date: 2026-07-02 21:01:18 Functions: 100.0 % 5 5

            Line data    Source code
       1              : #include "backend/catalog/emulator_ml_tvf_extensions.h"
       2              : 
       3              : #include <memory>
       4              : #include <string>
       5              : #include <vector>
       6              : 
       7              : #include "absl/status/status.h"
       8              : #include "absl/strings/str_join.h"
       9              : #include "backend/catalog/emulator_ml_test_catalog.h"
      10              : #include "googlesql/public/analyzer.h"
      11              : #include "googlesql/public/analyzer_options.h"
      12              : #include "googlesql/public/builtin_function_options.h"
      13              : #include "googlesql/public/catalog.h"
      14              : #include "googlesql/public/language_options.h"
      15              : #include "googlesql/public/options.pb.h"
      16              : #include "googlesql/public/types/type_factory.h"
      17              : #include "gtest/gtest.h"
      18              : 
      19              : namespace bigquery_emulator {
      20              : namespace backend {
      21              : namespace catalog {
      22              : namespace {
      23              : 
      24              : class EmulatorMlTvfCatalogTest : public ::testing::Test {
      25              :  protected:
      26            2 :   void SetUp() override {
      27            2 :     type_factory_ = std::make_unique<::googlesql::TypeFactory>();
      28            2 :     catalog_ = std::make_unique<EmulatorMlTestCatalog>("test_catalog",
      29            2 :                                                        type_factory_.get());
      30            2 :     ::googlesql::LanguageOptions language;
      31            2 :     language.EnableMaximumLanguageFeaturesForDevelopment();
      32            2 :     language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
      33            2 :     language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
      34            2 :     ASSERT_TRUE(catalog_
      35            2 :                     ->AddBuiltinFunctionsAndTypes(
      36            2 :                         ::googlesql::BuiltinFunctionOptions(language))
      37            2 :                     .ok());
      38            2 :     RegisterEmulatorMlTvfStubs(*catalog_);
      39            2 :   }
      40              : 
      41              :   std::unique_ptr<::googlesql::TypeFactory> type_factory_{};
      42              :   std::unique_ptr<EmulatorMlTestCatalog> catalog_{};
      43              : };
      44              : 
      45              : // Mirrors `GoogleSqlCatalog::FindTableValuedFunction` wiring so regressions
      46              : // in the unqualified fallback helper cannot stack-overflow the engine.
      47              : class CatalogWithTvfFallbackOverride : public ::googlesql::SimpleCatalog {
      48              :  public:
      49              :   using ::googlesql::SimpleCatalog::SimpleCatalog;
      50              : 
      51              :   absl::Status FindTableValuedFunction(
      52              :       const absl::Span<const std::string>& path,
      53              :       const ::googlesql::TableValuedFunction** function,
      54            1 :       const FindOptions& options = FindOptions()) override {
      55            1 :     return FindTableValuedFunctionWithUnqualifiedFallback(
      56            1 :         *this, path, function, options);
      57            1 :   }
      58              : };
      59              : 
      60            1 : TEST_F(EmulatorMlTvfCatalogTest, RegistersLookupPathsForMlPredict) {
      61            1 :   const ::googlesql::TableValuedFunction* tvf = nullptr;
      62            1 :   for (const std::vector<std::string> path :
      63            1 :        {std::vector<std::string>{"ML", "PREDICT"},
      64            1 :         std::vector<std::string>{"ml", "predict"}}) {
      65            1 :     SCOPED_TRACE(path[0] + "." + path[1]);
      66            1 :     absl::Status st = catalog_->FindTableValuedFunction(path, &tvf);
      67            1 :     if (st.ok()) {
      68            1 :       ASSERT_NE(tvf, nullptr);
      69            1 :       EXPECT_EQ(tvf->SQLName(), "ML.PREDICT");
      70            1 :       return;
      71            1 :     }
      72            1 :   }
      73            0 :   FAIL() << "FindTableValuedFunction did not resolve ML.PREDICT";
      74            1 : }
      75              : 
      76            1 : TEST(EmulatorMlTvfExtensionsTest, TvfFallbackOverrideDoesNotRecurse) {
      77            1 :   ::googlesql::TypeFactory type_factory;
      78            1 :   CatalogWithTvfFallbackOverride catalog("test_catalog", &type_factory);
      79            1 :   ::googlesql::LanguageOptions language;
      80            1 :   language.EnableMaximumLanguageFeaturesForDevelopment();
      81            1 :   language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
      82            1 :   language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
      83            1 :   ASSERT_TRUE(catalog
      84            1 :                   .AddBuiltinFunctionsAndTypes(
      85            1 :                       ::googlesql::BuiltinFunctionOptions(language))
      86            1 :                   .ok());
      87            1 :   RegisterEmulatorMlTvfStubs(catalog);
      88              : 
      89            1 :   const ::googlesql::TableValuedFunction* tvf = nullptr;
      90            1 :   ASSERT_TRUE(catalog.FindTableValuedFunction({"ML", "PREDICT"}, &tvf).ok());
      91            1 :   ASSERT_NE(tvf, nullptr);
      92            1 :   EXPECT_EQ(tvf->SQLName(), "ML.PREDICT");
      93            1 : }
      94              : 
      95            1 : TEST_F(EmulatorMlTvfCatalogTest, AnalyzeMlPredictSucceeds) {
      96            1 :   ::googlesql::LanguageOptions language;
      97            1 :   language.EnableMaximumLanguageFeaturesForDevelopment();
      98            1 :   language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
      99            1 :   language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
     100            1 :   language.set_name_resolution_mode(::googlesql::NAME_RESOLUTION_DEFAULT);
     101            1 :   language.SetSupportsAllStatementKinds();
     102            1 :   ::googlesql::AnalyzerOptions options(language);
     103            1 :   options.set_error_message_mode(::googlesql::ERROR_MESSAGE_ONE_LINE);
     104            1 :   options.CreateDefaultArenasIfNotSet();
     105              : 
     106            1 :   std::unique_ptr<const ::googlesql::AnalyzerOutput> output;
     107            1 :   absl::Status st = ::googlesql::AnalyzeStatement(
     108            1 :       "SELECT * FROM ML.PREDICT(MODEL `ds.unregistered_model`, "
     109            1 :       "(SELECT 1.0 AS f1))",
     110            1 :       options,
     111            1 :       catalog_.get(),
     112            1 :       type_factory_.get(),
     113            1 :       &output);
     114            2 :   EXPECT_TRUE(st.ok()) << st;
     115            1 : }
     116              : 
     117              : }  // namespace
     118              : }  // namespace catalog
     119              : }  // namespace backend
     120              : }  // namespace bigquery_emulator
        

Generated by: LCOV version 2.0-1