Line data Source code
1 : #include "backend/catalog/python_udf_registry.h"
2 :
3 : #include "gtest/gtest.h"
4 :
5 : namespace bigquery_emulator {
6 : namespace backend {
7 : namespace catalog {
8 : namespace {
9 :
10 1 : TEST(PythonUdfRegistryTest, ParsePythonUdfFromDdlExtractsPackages) {
11 1 : const char* ddl = R"(
12 1 : CREATE FUNCTION py_lxml(x STRING) RETURNS STRING
13 1 : LANGUAGE python
14 1 : OPTIONS (entry_point='do_lxml', packages=['lxml'])
15 1 : AS r"""
16 1 : from lxml import etree
17 1 : def do_lxml(x):
18 1 : return x
19 1 : """)";
20 1 : absl::StatusOr<PythonUdfDefinition> def_or =
21 1 : ParsePythonUdfFromDdl(ddl, "py_lxml");
22 2 : ASSERT_TRUE(def_or.ok()) << def_or.status();
23 1 : ASSERT_EQ(def_or->packages.size(), 1u);
24 1 : EXPECT_EQ(def_or->packages[0], "lxml");
25 1 : EXPECT_EQ(def_or->entry_point, "do_lxml");
26 1 : }
27 :
28 1 : TEST(PythonUdfRegistryTest, ParsePythonUdfFromDdlRejectsAggregate) {
29 1 : const char* ddl = R"(
30 1 : CREATE AGGREGATE FUNCTION weighted_avg(x FLOAT64, w FLOAT64)
31 1 : RETURNS FLOAT64
32 1 : LANGUAGE python
33 1 : OPTIONS (entry_point='weighted_avg')
34 1 : AS R"""
35 1 : def weighted_avg(x, w):
36 1 : return sum(x * w) / sum(w)
37 1 : """)";
38 1 : absl::StatusOr<PythonUdfDefinition> def_or =
39 1 : ParsePythonUdfFromDdl(ddl, "weighted_avg");
40 1 : ASSERT_FALSE(def_or.ok());
41 1 : EXPECT_EQ(def_or.status().message(),
42 1 : "CREATE AGGREGATE FUNCTION with language python is not supported");
43 1 : }
44 :
45 : } // namespace
46 : } // namespace catalog
47 : } // namespace backend
48 : } // namespace bigquery_emulator
|