LCOV - code coverage report
Current view: top level - backend/engine/semantic/stubs - ml_test.cc (source / functions) Coverage Total Hit
Test: _coverage_report.dat Lines: 96.5 % 172 166
Test Date: 2026-07-02 21:01:18 Functions: 100.0 % 10 10

            Line data    Source code
       1              : #include "backend/engine/semantic/stubs/ml.h"
       2              : 
       3              : #include <memory>
       4              : #include <string>
       5              : #include <vector>
       6              : 
       7              : #include "absl/status/status.h"
       8              : #include "absl/status/statusor.h"
       9              : #include "backend/catalog/emulator_ml_test_catalog.h"
      10              : #include "backend/catalog/emulator_ml_tvf_extensions.h"
      11              : #include "backend/engine/semantic/eval_tvf.h"
      12              : #include "backend/engine/semantic/value.h"
      13              : #include "googlesql/public/analyzer.h"
      14              : #include "googlesql/public/analyzer_options.h"
      15              : #include "googlesql/public/analyzer_output.h"
      16              : #include "googlesql/public/builtin_function_options.h"
      17              : #include "googlesql/public/catalog.h"
      18              : #include "googlesql/public/language_options.h"
      19              : #include "googlesql/public/options.pb.h"
      20              : #include "googlesql/public/types/type_factory.h"
      21              : #include "googlesql/resolved_ast/resolved_ast.h"
      22              : #include "googlesql/resolved_ast/resolved_ast_visitor.h"
      23              : #include "googlesql/resolved_ast/resolved_node_kind.pb.h"
      24              : #include "gtest/gtest.h"
      25              : 
      26              : namespace bigquery_emulator {
      27              : namespace backend {
      28              : namespace engine {
      29              : namespace semantic {
      30              : namespace stubs {
      31              : namespace {
      32              : 
      33            3 : void ExpectFloatingPointNear(const Value& value, double expected) {
      34            3 :   ASSERT_FALSE(value.is_null());
      35            3 :   if (value.type()->IsDouble()) {
      36            3 :     EXPECT_DOUBLE_EQ(value.double_value(), expected);
      37            3 :   } else {
      38            0 :     EXPECT_FLOAT_EQ(value.float_value(), static_cast<float>(expected));
      39            0 :   }
      40            3 : }
      41              : 
      42              : class TvfScanFinder : public ::googlesql::ResolvedASTVisitor {
      43              :  public:
      44              :   const ::googlesql::ResolvedTVFScan* found = nullptr;
      45              : 
      46              :   absl::Status VisitResolvedTVFScan(
      47            4 :       const ::googlesql::ResolvedTVFScan* node) override {
      48            4 :     found = node;
      49            4 :     return absl::OkStatus();
      50            4 :   }
      51              : 
      52            4 :   absl::Status DefaultVisit(const ::googlesql::ResolvedNode* node) override {
      53            4 :     return ::googlesql::ResolvedASTVisitor::DefaultVisit(node);
      54            4 :   }
      55              : };
      56              : 
      57            4 : ::googlesql::AnalyzerOptions MakeAnalyzerOptions() {
      58            4 :   ::googlesql::LanguageOptions language;
      59            4 :   language.EnableMaximumLanguageFeaturesForDevelopment();
      60            4 :   language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
      61            4 :   language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
      62            4 :   language.set_name_resolution_mode(::googlesql::NAME_RESOLUTION_DEFAULT);
      63            4 :   language.SetSupportsAllStatementKinds();
      64            4 :   ::googlesql::AnalyzerOptions options(language);
      65            4 :   options.set_error_message_mode(::googlesql::ERROR_MESSAGE_ONE_LINE);
      66            4 :   options.disable_rewrite(::googlesql::REWRITE_PIVOT);
      67            4 :   options.disable_rewrite(::googlesql::REWRITE_UNPIVOT);
      68            4 :   options.CreateDefaultArenasIfNotSet();
      69            4 :   return options;
      70            4 : }
      71              : 
      72              : class MlStubTest : public ::testing::Test {
      73              :  protected:
      74            4 :   void SetUp() override {
      75            4 :     type_factory_ = std::make_unique<::googlesql::TypeFactory>();
      76            4 :     catalog_ = std::make_unique<catalog::EmulatorMlTestCatalog>(
      77            4 :         "test_catalog", type_factory_.get());
      78            4 :     ::googlesql::LanguageOptions language;
      79            4 :     language.EnableMaximumLanguageFeaturesForDevelopment();
      80            4 :     language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
      81            4 :     language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
      82            4 :     ASSERT_TRUE(catalog_
      83            4 :                     ->AddBuiltinFunctionsAndTypes(
      84            4 :                         ::googlesql::BuiltinFunctionOptions(language))
      85            4 :                     .ok());
      86            4 :     catalog::RegisterEmulatorMlTvfStubs(*catalog_);
      87            4 :   }
      88              : 
      89            4 :   const ::googlesql::ResolvedTVFScan* AnalyzeTvfScan(absl::string_view sql) {
      90            4 :     last_output_.reset();
      91            4 :     ::googlesql::AnalyzerOptions options = MakeAnalyzerOptions();
      92            4 :     absl::Status s = ::googlesql::AnalyzeStatement(
      93            4 :         sql, options, catalog_.get(), type_factory_.get(), &last_output_);
      94            8 :     EXPECT_TRUE(s.ok()) << s;
      95            4 :     if (!s.ok() || last_output_ == nullptr) {
      96            0 :       return nullptr;
      97            0 :     }
      98            4 :     const ::googlesql::ResolvedStatement* stmt = nullptr;
      99            4 :     stmt = last_output_->resolved_statement();
     100            4 :     if (stmt == nullptr ||
     101            4 :         stmt->node_kind() != ::googlesql::RESOLVED_QUERY_STMT) {
     102            0 :       return nullptr;
     103            0 :     }
     104            4 :     const auto* query = stmt->GetAs<::googlesql::ResolvedQueryStmt>();
     105            4 :     TvfScanFinder finder;
     106            4 :     absl::Status walk = query->query()->Accept(&finder);
     107            8 :     EXPECT_TRUE(walk.ok()) << walk;
     108            4 :     return finder.found;
     109            4 :   }
     110              : 
     111              :   std::unique_ptr<::googlesql::TypeFactory> type_factory_{};
     112              :   std::unique_ptr<catalog::EmulatorMlTestCatalog> catalog_{};
     113              :   std::unique_ptr<const ::googlesql::AnalyzerOutput> last_output_{};
     114              : };
     115              : 
     116            1 : TEST_F(MlStubTest, PredictPassesThroughInputAndNullPredictedLabel) {
     117            1 :   const ::googlesql::ResolvedTVFScan* tvf = AnalyzeTvfScan(
     118            1 :       "SELECT f1, label, predicted_label FROM ML.PREDICT("
     119            1 :       "MODEL `ds.unregistered_model`, "
     120            1 :       "(SELECT 2.0 AS f1, 3.0 AS label))");
     121            1 :   ASSERT_NE(tvf, nullptr);
     122              : 
     123            1 :   const ::googlesql::ResolvedScan* input_scan = nullptr;
     124            2 :   for (int i = 0; i < tvf->argument_list_size(); ++i) {
     125            2 :     const auto* arg = tvf->argument_list(i);
     126            2 :     if (arg != nullptr && arg->scan() != nullptr) {
     127            1 :       input_scan = arg->scan();
     128            1 :       break;
     129            1 :     }
     130            2 :   }
     131            1 :   ASSERT_NE(input_scan, nullptr);
     132            1 :   ASSERT_GE(input_scan->column_list_size(), 2);
     133              : 
     134            1 :   ColumnBindings input_row;
     135            1 :   const ::googlesql::ResolvedColumn& f1_col = input_scan->column_list(0);
     136            1 :   const ::googlesql::ResolvedColumn& label_col = input_scan->column_list(1);
     137            1 :   input_row.emplace(
     138            1 :       f1_col.column_id(),
     139            1 :       f1_col.type()->IsDouble() ? Value::Double(2.0) : Value::Float(2.0f));
     140            1 :   input_row.emplace(
     141            1 :       label_col.column_id(),
     142            1 :       label_col.type()->IsDouble() ? Value::Double(3.0) : Value::Float(3.0f));
     143              : 
     144            1 :   auto rows = MlPredictStub(*tvf, {input_row}, input_scan);
     145            2 :   ASSERT_TRUE(rows.ok()) << rows.status();
     146            1 :   ASSERT_EQ(rows->size(), 1u);
     147            1 :   const ColumnBindings& out = (*rows)[0];
     148            1 :   ASSERT_EQ(out.size(), 3u);
     149            1 :   bool saw_f1 = false;
     150            1 :   bool saw_label = false;
     151            1 :   bool saw_predicted = false;
     152            4 :   for (int i = 0; i < tvf->column_list_size(); ++i) {
     153            3 :     const ::googlesql::ResolvedColumn& col = tvf->column_list(i);
     154            3 :     const auto it = out.find(col.column_id());
     155            3 :     ASSERT_NE(it, out.end());
     156            3 :     if (col.name() == "f1") {
     157            1 :       saw_f1 = true;
     158            1 :       ExpectFloatingPointNear(it->second, 2.0);
     159            2 :     } else if (col.name() == "label") {
     160            1 :       saw_label = true;
     161            1 :       ExpectFloatingPointNear(it->second, 3.0);
     162            1 :     } else if (col.name() == "predicted_label") {
     163            1 :       saw_predicted = true;
     164            1 :       EXPECT_TRUE(it->second.is_null());
     165            1 :       EXPECT_TRUE(it->second.type()->IsFloatingPoint());
     166            1 :     }
     167            3 :   }
     168            1 :   EXPECT_TRUE(saw_f1);
     169            1 :   EXPECT_TRUE(saw_label);
     170            1 :   EXPECT_TRUE(saw_predicted);
     171            1 : }
     172              : 
     173            1 : TEST_F(MlStubTest, EvaluateReturnsSingleNullMetricsRow) {
     174            1 :   const ::googlesql::ResolvedTVFScan* tvf = AnalyzeTvfScan(
     175            1 :       "SELECT * FROM ML.EVALUATE(MODEL `ds.unregistered_model`)");
     176            1 :   ASSERT_NE(tvf, nullptr);
     177              : 
     178            1 :   auto rows = MlEvaluateStub(*tvf);
     179            2 :   ASSERT_TRUE(rows.ok()) << rows.status();
     180            1 :   ASSERT_EQ(rows->size(), 1u);
     181            3 :   for (int i = 0; i < tvf->column_list_size(); ++i) {
     182            2 :     const ::googlesql::ResolvedColumn& col = tvf->column_list(i);
     183            2 :     const auto it = (*rows)[0].find(col.column_id());
     184            2 :     ASSERT_NE(it, (*rows)[0].end());
     185            4 :     EXPECT_TRUE(it->second.is_null()) << "column " << col.name();
     186            2 :   }
     187            1 : }
     188              : 
     189            1 : TEST_F(MlStubTest, ForecastReturnsSingleNullForecastRow) {
     190            1 :   const ::googlesql::ResolvedTVFScan* tvf = AnalyzeTvfScan(
     191            1 :       "SELECT * FROM ML.FORECAST(MODEL `ds.unregistered_model`, "
     192            1 :       "STRUCT(7 AS horizon))");
     193            1 :   ASSERT_NE(tvf, nullptr);
     194              : 
     195            1 :   auto rows = MlForecastStub(*tvf);
     196            2 :   ASSERT_TRUE(rows.ok()) << rows.status();
     197            1 :   ASSERT_EQ(rows->size(), 1u);
     198            3 :   for (int i = 0; i < tvf->column_list_size(); ++i) {
     199            2 :     const ::googlesql::ResolvedColumn& col = tvf->column_list(i);
     200            2 :     const auto it = (*rows)[0].find(col.column_id());
     201            2 :     ASSERT_NE(it, (*rows)[0].end());
     202            4 :     EXPECT_TRUE(it->second.is_null()) << "column " << col.name();
     203            2 :   }
     204            1 : }
     205              : 
     206            1 : TEST_F(MlStubTest, MaterializeTvfScanPredictEndToEnd) {
     207            1 :   const ::googlesql::ResolvedTVFScan* tvf = AnalyzeTvfScan(
     208            1 :       "SELECT f1, predicted_label FROM ML.PREDICT("
     209            1 :       "MODEL `ds.unregistered_model`, (SELECT 4.0 AS f1))");
     210            1 :   ASSERT_NE(tvf, nullptr);
     211              : 
     212            1 :   EvalContext ctx;
     213            1 :   auto rows = MaterializeTvfScan(*tvf, ctx);
     214            2 :   ASSERT_TRUE(rows.ok()) << rows.status();
     215            1 :   ASSERT_EQ(rows->size(), 1u);
     216            1 :   bool saw_f1 = false;
     217            1 :   bool saw_predicted = false;
     218            3 :   for (int i = 0; i < tvf->column_list_size(); ++i) {
     219            2 :     const ::googlesql::ResolvedColumn& col = tvf->column_list(i);
     220            2 :     const auto it = (*rows)[0].find(col.column_id());
     221            2 :     ASSERT_NE(it, (*rows)[0].end());
     222            2 :     if (col.name() == "f1") {
     223            1 :       saw_f1 = true;
     224            1 :       ExpectFloatingPointNear(it->second, 4.0);
     225            1 :     } else if (col.name() == "predicted_label") {
     226            1 :       saw_predicted = true;
     227            1 :       EXPECT_TRUE(it->second.is_null());
     228            1 :     }
     229            2 :   }
     230            1 :   EXPECT_TRUE(saw_f1);
     231            1 :   EXPECT_TRUE(saw_predicted);
     232            1 : }
     233              : 
     234              : }  // namespace
     235              : }  // namespace stubs
     236              : }  // namespace semantic
     237              : }  // namespace engine
     238              : }  // namespace backend
     239              : }  // namespace bigquery_emulator
        

Generated by: LCOV version 2.0-1