Line data Source code
1 : #include "backend/catalog/emulator_ml_tvf_extensions.h"
2 :
3 : #include <memory>
4 : #include <string>
5 : #include <vector>
6 :
7 : #include "absl/status/status.h"
8 : #include "absl/strings/str_join.h"
9 : #include "backend/catalog/emulator_ml_test_catalog.h"
10 : #include "googlesql/public/analyzer.h"
11 : #include "googlesql/public/analyzer_options.h"
12 : #include "googlesql/public/builtin_function_options.h"
13 : #include "googlesql/public/catalog.h"
14 : #include "googlesql/public/language_options.h"
15 : #include "googlesql/public/options.pb.h"
16 : #include "googlesql/public/types/type_factory.h"
17 : #include "gtest/gtest.h"
18 :
19 : namespace bigquery_emulator {
20 : namespace backend {
21 : namespace catalog {
22 : namespace {
23 :
24 : class EmulatorMlTvfCatalogTest : public ::testing::Test {
25 : protected:
26 2 : void SetUp() override {
27 2 : type_factory_ = std::make_unique<::googlesql::TypeFactory>();
28 2 : catalog_ = std::make_unique<EmulatorMlTestCatalog>("test_catalog",
29 2 : type_factory_.get());
30 2 : ::googlesql::LanguageOptions language;
31 2 : language.EnableMaximumLanguageFeaturesForDevelopment();
32 2 : language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
33 2 : language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
34 2 : ASSERT_TRUE(catalog_
35 2 : ->AddBuiltinFunctionsAndTypes(
36 2 : ::googlesql::BuiltinFunctionOptions(language))
37 2 : .ok());
38 2 : RegisterEmulatorMlTvfStubs(*catalog_);
39 2 : }
40 :
41 : std::unique_ptr<::googlesql::TypeFactory> type_factory_{};
42 : std::unique_ptr<EmulatorMlTestCatalog> catalog_{};
43 : };
44 :
45 : // Mirrors `GoogleSqlCatalog::FindTableValuedFunction` wiring so regressions
46 : // in the unqualified fallback helper cannot stack-overflow the engine.
47 : class CatalogWithTvfFallbackOverride : public ::googlesql::SimpleCatalog {
48 : public:
49 : using ::googlesql::SimpleCatalog::SimpleCatalog;
50 :
51 : absl::Status FindTableValuedFunction(
52 : const absl::Span<const std::string>& path,
53 : const ::googlesql::TableValuedFunction** function,
54 1 : const FindOptions& options = FindOptions()) override {
55 1 : return FindTableValuedFunctionWithUnqualifiedFallback(
56 1 : *this, path, function, options);
57 1 : }
58 : };
59 :
60 1 : TEST_F(EmulatorMlTvfCatalogTest, RegistersLookupPathsForMlPredict) {
61 1 : const ::googlesql::TableValuedFunction* tvf = nullptr;
62 1 : for (const std::vector<std::string> path :
63 1 : {std::vector<std::string>{"ML", "PREDICT"},
64 1 : std::vector<std::string>{"ml", "predict"}}) {
65 1 : SCOPED_TRACE(path[0] + "." + path[1]);
66 1 : absl::Status st = catalog_->FindTableValuedFunction(path, &tvf);
67 1 : if (st.ok()) {
68 1 : ASSERT_NE(tvf, nullptr);
69 1 : EXPECT_EQ(tvf->SQLName(), "ML.PREDICT");
70 1 : return;
71 1 : }
72 1 : }
73 0 : FAIL() << "FindTableValuedFunction did not resolve ML.PREDICT";
74 1 : }
75 :
76 1 : TEST(EmulatorMlTvfExtensionsTest, TvfFallbackOverrideDoesNotRecurse) {
77 1 : ::googlesql::TypeFactory type_factory;
78 1 : CatalogWithTvfFallbackOverride catalog("test_catalog", &type_factory);
79 1 : ::googlesql::LanguageOptions language;
80 1 : language.EnableMaximumLanguageFeaturesForDevelopment();
81 1 : language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
82 1 : language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
83 1 : ASSERT_TRUE(catalog
84 1 : .AddBuiltinFunctionsAndTypes(
85 1 : ::googlesql::BuiltinFunctionOptions(language))
86 1 : .ok());
87 1 : RegisterEmulatorMlTvfStubs(catalog);
88 :
89 1 : const ::googlesql::TableValuedFunction* tvf = nullptr;
90 1 : ASSERT_TRUE(catalog.FindTableValuedFunction({"ML", "PREDICT"}, &tvf).ok());
91 1 : ASSERT_NE(tvf, nullptr);
92 1 : EXPECT_EQ(tvf->SQLName(), "ML.PREDICT");
93 1 : }
94 :
95 1 : TEST_F(EmulatorMlTvfCatalogTest, AnalyzeMlPredictSucceeds) {
96 1 : ::googlesql::LanguageOptions language;
97 1 : language.EnableMaximumLanguageFeaturesForDevelopment();
98 1 : language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
99 1 : language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
100 1 : language.set_name_resolution_mode(::googlesql::NAME_RESOLUTION_DEFAULT);
101 1 : language.SetSupportsAllStatementKinds();
102 1 : ::googlesql::AnalyzerOptions options(language);
103 1 : options.set_error_message_mode(::googlesql::ERROR_MESSAGE_ONE_LINE);
104 1 : options.CreateDefaultArenasIfNotSet();
105 :
106 1 : std::unique_ptr<const ::googlesql::AnalyzerOutput> output;
107 1 : absl::Status st = ::googlesql::AnalyzeStatement(
108 1 : "SELECT * FROM ML.PREDICT(MODEL `ds.unregistered_model`, "
109 1 : "(SELECT 1.0 AS f1))",
110 1 : options,
111 1 : catalog_.get(),
112 1 : type_factory_.get(),
113 1 : &output);
114 2 : EXPECT_TRUE(st.ok()) << st;
115 1 : }
116 :
117 : } // namespace
118 : } // namespace catalog
119 : } // namespace backend
120 : } // namespace bigquery_emulator
|