Line data Source code
1 : #include "backend/engine/semantic/stubs/ml.h"
2 :
3 : #include <memory>
4 : #include <string>
5 : #include <vector>
6 :
7 : #include "absl/status/status.h"
8 : #include "absl/status/statusor.h"
9 : #include "backend/catalog/emulator_ml_test_catalog.h"
10 : #include "backend/catalog/emulator_ml_tvf_extensions.h"
11 : #include "backend/engine/semantic/eval_tvf.h"
12 : #include "backend/engine/semantic/value.h"
13 : #include "googlesql/public/analyzer.h"
14 : #include "googlesql/public/analyzer_options.h"
15 : #include "googlesql/public/analyzer_output.h"
16 : #include "googlesql/public/builtin_function_options.h"
17 : #include "googlesql/public/catalog.h"
18 : #include "googlesql/public/language_options.h"
19 : #include "googlesql/public/options.pb.h"
20 : #include "googlesql/public/types/type_factory.h"
21 : #include "googlesql/resolved_ast/resolved_ast.h"
22 : #include "googlesql/resolved_ast/resolved_ast_visitor.h"
23 : #include "googlesql/resolved_ast/resolved_node_kind.pb.h"
24 : #include "gtest/gtest.h"
25 :
26 : namespace bigquery_emulator {
27 : namespace backend {
28 : namespace engine {
29 : namespace semantic {
30 : namespace stubs {
31 : namespace {
32 :
33 3 : void ExpectFloatingPointNear(const Value& value, double expected) {
34 3 : ASSERT_FALSE(value.is_null());
35 3 : if (value.type()->IsDouble()) {
36 3 : EXPECT_DOUBLE_EQ(value.double_value(), expected);
37 3 : } else {
38 0 : EXPECT_FLOAT_EQ(value.float_value(), static_cast<float>(expected));
39 0 : }
40 3 : }
41 :
42 : class TvfScanFinder : public ::googlesql::ResolvedASTVisitor {
43 : public:
44 : const ::googlesql::ResolvedTVFScan* found = nullptr;
45 :
46 : absl::Status VisitResolvedTVFScan(
47 4 : const ::googlesql::ResolvedTVFScan* node) override {
48 4 : found = node;
49 4 : return absl::OkStatus();
50 4 : }
51 :
52 4 : absl::Status DefaultVisit(const ::googlesql::ResolvedNode* node) override {
53 4 : return ::googlesql::ResolvedASTVisitor::DefaultVisit(node);
54 4 : }
55 : };
56 :
57 4 : ::googlesql::AnalyzerOptions MakeAnalyzerOptions() {
58 4 : ::googlesql::LanguageOptions language;
59 4 : language.EnableMaximumLanguageFeaturesForDevelopment();
60 4 : language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
61 4 : language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
62 4 : language.set_name_resolution_mode(::googlesql::NAME_RESOLUTION_DEFAULT);
63 4 : language.SetSupportsAllStatementKinds();
64 4 : ::googlesql::AnalyzerOptions options(language);
65 4 : options.set_error_message_mode(::googlesql::ERROR_MESSAGE_ONE_LINE);
66 4 : options.disable_rewrite(::googlesql::REWRITE_PIVOT);
67 4 : options.disable_rewrite(::googlesql::REWRITE_UNPIVOT);
68 4 : options.CreateDefaultArenasIfNotSet();
69 4 : return options;
70 4 : }
71 :
72 : class MlStubTest : public ::testing::Test {
73 : protected:
74 4 : void SetUp() override {
75 4 : type_factory_ = std::make_unique<::googlesql::TypeFactory>();
76 4 : catalog_ = std::make_unique<catalog::EmulatorMlTestCatalog>(
77 4 : "test_catalog", type_factory_.get());
78 4 : ::googlesql::LanguageOptions language;
79 4 : language.EnableMaximumLanguageFeaturesForDevelopment();
80 4 : language.EnableLanguageFeature(::googlesql::FEATURE_REMOTE_MODEL);
81 4 : language.set_product_mode(::googlesql::PRODUCT_EXTERNAL);
82 4 : ASSERT_TRUE(catalog_
83 4 : ->AddBuiltinFunctionsAndTypes(
84 4 : ::googlesql::BuiltinFunctionOptions(language))
85 4 : .ok());
86 4 : catalog::RegisterEmulatorMlTvfStubs(*catalog_);
87 4 : }
88 :
89 4 : const ::googlesql::ResolvedTVFScan* AnalyzeTvfScan(absl::string_view sql) {
90 4 : last_output_.reset();
91 4 : ::googlesql::AnalyzerOptions options = MakeAnalyzerOptions();
92 4 : absl::Status s = ::googlesql::AnalyzeStatement(
93 4 : sql, options, catalog_.get(), type_factory_.get(), &last_output_);
94 8 : EXPECT_TRUE(s.ok()) << s;
95 4 : if (!s.ok() || last_output_ == nullptr) {
96 0 : return nullptr;
97 0 : }
98 4 : const ::googlesql::ResolvedStatement* stmt = nullptr;
99 4 : stmt = last_output_->resolved_statement();
100 4 : if (stmt == nullptr ||
101 4 : stmt->node_kind() != ::googlesql::RESOLVED_QUERY_STMT) {
102 0 : return nullptr;
103 0 : }
104 4 : const auto* query = stmt->GetAs<::googlesql::ResolvedQueryStmt>();
105 4 : TvfScanFinder finder;
106 4 : absl::Status walk = query->query()->Accept(&finder);
107 8 : EXPECT_TRUE(walk.ok()) << walk;
108 4 : return finder.found;
109 4 : }
110 :
111 : std::unique_ptr<::googlesql::TypeFactory> type_factory_{};
112 : std::unique_ptr<catalog::EmulatorMlTestCatalog> catalog_{};
113 : std::unique_ptr<const ::googlesql::AnalyzerOutput> last_output_{};
114 : };
115 :
116 1 : TEST_F(MlStubTest, PredictPassesThroughInputAndNullPredictedLabel) {
117 1 : const ::googlesql::ResolvedTVFScan* tvf = AnalyzeTvfScan(
118 1 : "SELECT f1, label, predicted_label FROM ML.PREDICT("
119 1 : "MODEL `ds.unregistered_model`, "
120 1 : "(SELECT 2.0 AS f1, 3.0 AS label))");
121 1 : ASSERT_NE(tvf, nullptr);
122 :
123 1 : const ::googlesql::ResolvedScan* input_scan = nullptr;
124 2 : for (int i = 0; i < tvf->argument_list_size(); ++i) {
125 2 : const auto* arg = tvf->argument_list(i);
126 2 : if (arg != nullptr && arg->scan() != nullptr) {
127 1 : input_scan = arg->scan();
128 1 : break;
129 1 : }
130 2 : }
131 1 : ASSERT_NE(input_scan, nullptr);
132 1 : ASSERT_GE(input_scan->column_list_size(), 2);
133 :
134 1 : ColumnBindings input_row;
135 1 : const ::googlesql::ResolvedColumn& f1_col = input_scan->column_list(0);
136 1 : const ::googlesql::ResolvedColumn& label_col = input_scan->column_list(1);
137 1 : input_row.emplace(
138 1 : f1_col.column_id(),
139 1 : f1_col.type()->IsDouble() ? Value::Double(2.0) : Value::Float(2.0f));
140 1 : input_row.emplace(
141 1 : label_col.column_id(),
142 1 : label_col.type()->IsDouble() ? Value::Double(3.0) : Value::Float(3.0f));
143 :
144 1 : auto rows = MlPredictStub(*tvf, {input_row}, input_scan);
145 2 : ASSERT_TRUE(rows.ok()) << rows.status();
146 1 : ASSERT_EQ(rows->size(), 1u);
147 1 : const ColumnBindings& out = (*rows)[0];
148 1 : ASSERT_EQ(out.size(), 3u);
149 1 : bool saw_f1 = false;
150 1 : bool saw_label = false;
151 1 : bool saw_predicted = false;
152 4 : for (int i = 0; i < tvf->column_list_size(); ++i) {
153 3 : const ::googlesql::ResolvedColumn& col = tvf->column_list(i);
154 3 : const auto it = out.find(col.column_id());
155 3 : ASSERT_NE(it, out.end());
156 3 : if (col.name() == "f1") {
157 1 : saw_f1 = true;
158 1 : ExpectFloatingPointNear(it->second, 2.0);
159 2 : } else if (col.name() == "label") {
160 1 : saw_label = true;
161 1 : ExpectFloatingPointNear(it->second, 3.0);
162 1 : } else if (col.name() == "predicted_label") {
163 1 : saw_predicted = true;
164 1 : EXPECT_TRUE(it->second.is_null());
165 1 : EXPECT_TRUE(it->second.type()->IsFloatingPoint());
166 1 : }
167 3 : }
168 1 : EXPECT_TRUE(saw_f1);
169 1 : EXPECT_TRUE(saw_label);
170 1 : EXPECT_TRUE(saw_predicted);
171 1 : }
172 :
173 1 : TEST_F(MlStubTest, EvaluateReturnsSingleNullMetricsRow) {
174 1 : const ::googlesql::ResolvedTVFScan* tvf = AnalyzeTvfScan(
175 1 : "SELECT * FROM ML.EVALUATE(MODEL `ds.unregistered_model`)");
176 1 : ASSERT_NE(tvf, nullptr);
177 :
178 1 : auto rows = MlEvaluateStub(*tvf);
179 2 : ASSERT_TRUE(rows.ok()) << rows.status();
180 1 : ASSERT_EQ(rows->size(), 1u);
181 3 : for (int i = 0; i < tvf->column_list_size(); ++i) {
182 2 : const ::googlesql::ResolvedColumn& col = tvf->column_list(i);
183 2 : const auto it = (*rows)[0].find(col.column_id());
184 2 : ASSERT_NE(it, (*rows)[0].end());
185 4 : EXPECT_TRUE(it->second.is_null()) << "column " << col.name();
186 2 : }
187 1 : }
188 :
189 1 : TEST_F(MlStubTest, ForecastReturnsSingleNullForecastRow) {
190 1 : const ::googlesql::ResolvedTVFScan* tvf = AnalyzeTvfScan(
191 1 : "SELECT * FROM ML.FORECAST(MODEL `ds.unregistered_model`, "
192 1 : "STRUCT(7 AS horizon))");
193 1 : ASSERT_NE(tvf, nullptr);
194 :
195 1 : auto rows = MlForecastStub(*tvf);
196 2 : ASSERT_TRUE(rows.ok()) << rows.status();
197 1 : ASSERT_EQ(rows->size(), 1u);
198 3 : for (int i = 0; i < tvf->column_list_size(); ++i) {
199 2 : const ::googlesql::ResolvedColumn& col = tvf->column_list(i);
200 2 : const auto it = (*rows)[0].find(col.column_id());
201 2 : ASSERT_NE(it, (*rows)[0].end());
202 4 : EXPECT_TRUE(it->second.is_null()) << "column " << col.name();
203 2 : }
204 1 : }
205 :
206 1 : TEST_F(MlStubTest, MaterializeTvfScanPredictEndToEnd) {
207 1 : const ::googlesql::ResolvedTVFScan* tvf = AnalyzeTvfScan(
208 1 : "SELECT f1, predicted_label FROM ML.PREDICT("
209 1 : "MODEL `ds.unregistered_model`, (SELECT 4.0 AS f1))");
210 1 : ASSERT_NE(tvf, nullptr);
211 :
212 1 : EvalContext ctx;
213 1 : auto rows = MaterializeTvfScan(*tvf, ctx);
214 2 : ASSERT_TRUE(rows.ok()) << rows.status();
215 1 : ASSERT_EQ(rows->size(), 1u);
216 1 : bool saw_f1 = false;
217 1 : bool saw_predicted = false;
218 3 : for (int i = 0; i < tvf->column_list_size(); ++i) {
219 2 : const ::googlesql::ResolvedColumn& col = tvf->column_list(i);
220 2 : const auto it = (*rows)[0].find(col.column_id());
221 2 : ASSERT_NE(it, (*rows)[0].end());
222 2 : if (col.name() == "f1") {
223 1 : saw_f1 = true;
224 1 : ExpectFloatingPointNear(it->second, 4.0);
225 1 : } else if (col.name() == "predicted_label") {
226 1 : saw_predicted = true;
227 1 : EXPECT_TRUE(it->second.is_null());
228 1 : }
229 2 : }
230 1 : EXPECT_TRUE(saw_f1);
231 1 : EXPECT_TRUE(saw_predicted);
232 1 : }
233 :
234 : } // namespace
235 : } // namespace stubs
236 : } // namespace semantic
237 : } // namespace engine
238 : } // namespace backend
239 : } // namespace bigquery_emulator
|