Skip to content
This repository was archived by the owner on Oct 15, 2025. It is now read-only.

Commit f991a9f

Browse files
committed
Simplify function interface
1 parent f6f3dda commit f991a9f

File tree

5 files changed

+218
-1
lines changed

5 files changed

+218
-1
lines changed

evadb/functions/My_SimpleUDF.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def My_SimpleUDF(cls, x: int) -> int:
    """Return ``x`` increased by 5.

    Args:
        cls: Placeholder first positional argument; the body never uses it.
            NOTE(review): presumably this receives the wrapper object when
            the function is invoked through ``SimpleUDF`` - confirm.
        x (int): Value to offset.

    Returns:
        int: ``x + 5``.
    """
    return x + 5

evadb/functions/simple_udf.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# coding=utf-8
2+
# Copyright 2018-2023 EvaDB
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import numpy as np
16+
import pandas as pd
17+
import importlib
18+
import pickle
19+
from pathlib import Path
20+
import typing
21+
22+
from evadb.catalog.catalog_type import NdArrayType
23+
from evadb.functions.abstract.abstract_function import AbstractFunction
24+
from evadb.functions.decorators.decorators import forward, setup
25+
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
26+
from evadb.configuration.constants import EvaDB_ROOT_DIR
27+
28+
class SimpleUDF(AbstractFunction):
    """Adapter that runs a plain, type-annotated Python callable as a function.

    The wrapped callable is stored in ``udf`` and its type hints in ``types``;
    both are populated by :meth:`set_udf`, which therefore has to run before
    :meth:`setup` (``setup`` reads ``self.types``).
    """

    @setup(cacheable=False, function_type="SimpleUDF", batchable=False)
    def setup(self):
        """Build the ``forward`` input/output descriptors from ``self.types``.

        Every annotated parameter becomes one input column; the ``return``
        annotation describes the single ``"output"`` column.
        """
        in_labels = [label for label in self.types if label != "return"]
        in_types = [self.convert_python_types(self.types[label]) for label in in_labels]
        # A callable without a return annotation has no "return" key; fall back
        # to ANYTYPE (via the default branch of convert_python_types) instead
        # of failing with KeyError.
        out_types = [self.convert_python_types(self.types.get("return"))]

        self.forward.tags["input"] = [
            PandasDataframe(
                columns=in_labels,
                column_types=in_types,
                # One shape tuple per column. The previous ``[(1) * n]`` built a
                # single-element list holding the integer ``n``: ``(1)`` is an
                # int, not a tuple.
                column_shapes=[(1,)] * len(in_labels),
            )
        ]

        self.forward.tags["output"] = [
            PandasDataframe(
                columns=["output"],
                column_types=out_types,
                column_shapes=[(1,)] * len(out_types),
            )
        ]

    @property
    def name(self) -> str:
        """Fixed display name of this wrapper."""
        return "SimpleUDF"

    @forward(None, None)
    def forward(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply the wrapped callable to every row of ``df``.

        Args:
            df (pd.DataFrame): Input rows.

        Returns:
            pd.DataFrame: Frame with one column, ``"output"``, holding the
            per-row results.
        """

        def _forward(row: pd.Series):
            # The whole row is handed over as a single argument.
            # NOTE(review): a callable declaring several parameters would not
            # get them unpacked from the row - confirm this is intended.
            return self.udf(row)

        ret = pd.DataFrame()
        ret["output"] = df.apply(_forward, axis=1)
        return ret

    def set_udf(self, classname: str, filepath: str):
        """Load the wrapped callable and cache its type hints.

        Paths containing ``{EvaDB_ROOT_DIR}/simple_udfs/`` are treated as
        pickled callables; any other path is imported as a Python source file
        and the attribute named ``classname`` is taken from it.

        Args:
            classname (str): Name of the callable inside the source file.
            filepath (str): Location of the pickle or of the source file.

        Raises:
            ImportError: The source file was found but could not be imported.
            FileNotFoundError: The source file does not exist.
            RuntimeError: Any other loading failure, including a source file
                that does not define ``classname``.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly where it is used.
        import importlib.util

        if f"{EvaDB_ROOT_DIR}/simple_udfs/" in filepath:
            # SECURITY: unpickling executes arbitrary code. Only files written
            # by this installation should ever be placed in simple_udfs/.
            # Read the file that was actually requested (it used to be a
            # hard-coded "Func_SimpleUDF") and close the handle afterwards.
            with open(filepath, "rb") as pickled:
                self.udf = pickle.load(pickled)
        else:
            try:
                abs_path = Path(filepath).resolve()
                spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            except ImportError as e:
                # ImportError in the case when we are able to find the file but not able to load the module
                err_msg = f"ImportError : Couldn't load function from {filepath} : {str(e)}. Not able to load the code provided in the file {abs_path}. Please ensure that the file contains the implementation code for the function."
                raise ImportError(err_msg) from e
            except FileNotFoundError as e:
                # FileNotFoundError in the case when we are not able to find the file at all at the path.
                err_msg = f"FileNotFoundError : Couldn't load function from {filepath} : {str(e)}. This might be because the function implementation file does not exist. Please ensure the file exists at {abs_path}"
                raise FileNotFoundError(err_msg) from e
            except Exception as e:
                # Default exception, we don't know what exactly went wrong so we just output the error message
                err_msg = f"Couldn't load function from {filepath} : {str(e)}."
                raise RuntimeError(err_msg) from e

            # Fail loudly when the requested name is absent; otherwise ``udf``
            # would stay unset (or keep a value from an earlier call).
            if not (classname and hasattr(module, classname)):
                raise RuntimeError(
                    f"Couldn't load function from {filepath} : "
                    f"{classname!r} is not defined in the file."
                )
            self.udf = getattr(module, classname)

        self.types = typing.get_type_hints(self.udf)

    def convert_python_types(self, type):
        """Translate a Python annotation into the matching ``NdArrayType``.

        Args:
            type: Annotation object taken from the callable's type hints. The
                parameter name shadows the builtin; it is kept so keyword
                callers continue to work.

        Returns:
            ``BOOL``, ``INT32``, ``FLOAT32`` or ``STR`` for ``bool``, ``int``,
            ``float`` and ``str`` respectively; ``ANYTYPE`` for anything else.
        """
        if type == bool:
            return NdArrayType.BOOL
        elif type == int:
            return NdArrayType.INT32
        elif type == float:
            return NdArrayType.FLOAT32
        elif type == str:
            return NdArrayType.STR
        else:
            return NdArrayType.ANYTYPE

evadb/interfaces/relational/db.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
import multiprocessing
1717

1818
import pandas
19+
import pickle
1920

20-
from evadb.configuration.constants import EvaDB_DATABASE_DIR
21+
from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_ROOT_DIR
2122
from evadb.database import EvaDBDatabase, init_evadb_instance
2223
from evadb.expression.tuple_value_expression import TupleValueExpression
2324
from evadb.functions.function_bootstrap_queries import init_builtin_functions
@@ -413,6 +414,30 @@ def create_function(
413414
function_name, if_not_exists, impl_path, type, **kwargs
414415
)
415416
return EvaDBQuery(self._evadb, stmt)
417+
418+
def create_simple_function(
    self,
    function_name: str,
    function: callable,
    if_not_exists: bool = True,
) -> "EvaDBQuery":
    """
    Create a function in the database by passing in a function instance.

    The callable is pickled to ``{EvaDB_ROOT_DIR}/simple_udfs/{function_name}``
    and that path is then registered through ``create_function``.

    Note:
        ``pickle`` stores plain functions by reference (module and qualified
        name), so ``function`` must be importable wherever it is loaded again.
        NOTE(review): the loader appears to route only names containing
        ``_SimpleUDF`` to the simple-function wrapper - confirm the naming
        requirement.

    Args:
        function_name (str): Name of the function to be created.
        function (callable): The function instance.
        if_not_exists (bool): If True, do not raise an error if the function already exists. If False, raise an error.

    Returns:
        EvaDBQuery: The EvaDBQuery object representing the function created.
    """
    # Function-scope import keeps this method self-contained.
    from pathlib import Path

    impl_path = f"{EvaDB_ROOT_DIR}/simple_udfs/{function_name}"
    # The target directory is not guaranteed to exist yet.
    Path(impl_path).parent.mkdir(parents=True, exist_ok=True)
    # Overwrite rather than append: with 'ab' a second registration under the
    # same name left the old payload at the start of the file, and a single
    # pickle.load() would keep returning the stale function.
    with open(impl_path, "wb") as f:
        pickle.dump(function, f)

    return self.create_function(function_name, if_not_exists, impl_path)
416441

417442
def create_table(
418443
self, table_name: str, if_not_exists: bool = True, columns: str = None, **kwargs

evadb/utils/generic_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from evadb.configuration.constants import EvaDB_INSTALLATION_DIR
3030
from evadb.utils.logging_manager import logger
31+
from evadb.configuration.constants import EvaDB_ROOT_DIR
3132

3233

3334
def validate_kwargs(
@@ -79,6 +80,14 @@ def load_function_class_from_file(filepath, classname=None):
7980
FileNotFoundError: If the file cannot be found.
8081
RuntimeError: Any other type of runtime error.
8182
"""
83+
simple_udf_filepath = None
84+
simple_udf_classname = None
85+
if classname and "_SimpleUDF" in classname:
86+
simple_udf_classname = classname
87+
classname = "SimpleUDF"
88+
simple_udf_filepath = filepath
89+
filepath = f"{EvaDB_ROOT_DIR}/evadb/functions/simple_udf.py"
90+
8291
try:
8392
abs_path = Path(filepath).resolve()
8493
spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path)
@@ -99,6 +108,10 @@ def load_function_class_from_file(filepath, classname=None):
99108

100109
# Try to load the specified class by name
101110
if classname and hasattr(module, classname):
111+
if classname == "SimpleUDF":
112+
cls = getattr(module, classname)
113+
cls.set_udf(cls, simple_udf_classname, simple_udf_filepath)
114+
return cls
102115
return getattr(module, classname)
103116

104117
# If class name not specified, check if there is only one class in the file
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# coding=utf-8
2+
# Copyright 2018-2023 EvaDB
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import unittest
16+
from test.util import suffix_pytest_xdist_worker_id_to_dir
17+
18+
import pytest
19+
import pandas as pd
20+
21+
from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_ROOT_DIR
22+
from evadb.interfaces.relational.db import connect
23+
from evadb.server.command_handler import execute_query_fetch_all
24+
25+
def Func_SimpleUDF(cls, x: int) -> int:
    """Return ``x`` increased by 10.

    Args:
        cls: Placeholder first positional argument; the body never uses it.
            NOTE(review): presumably this receives the wrapper object when
            the function is invoked through the simple-function wrapper -
            confirm.
        x (int): Value to offset.

    Returns:
        int: ``x + 10``.
    """
    return x + 10
27+
28+
@pytest.mark.notparallel
class SimpleFunctionTests(unittest.TestCase):
    """Tests for simple functions registered from a file and from a callable."""

    def setUp(self):
        """Connect to the per-worker database directory and reset its catalog."""
        self.db_dir = suffix_pytest_xdist_worker_id_to_dir(EvaDB_DATABASE_DIR)
        self.conn = connect(self.db_dir)
        self.evadb = self.conn._evadb
        self.evadb.catalog().reset()

    def tearDown(self):
        """Drop the table and both functions the tests may have created."""
        execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS test_table;")
        execute_query_fetch_all(self.evadb, "DROP FUNCTION IF EXISTS My_SimpleUDF;")
        execute_query_fetch_all(self.evadb, "DROP FUNCTION IF EXISTS Func_SimpleUDF;")

    def _cursor_with_test_table(self):
        """Return a cursor after creating ``test_table`` with the single row ``val = 1``."""
        cursor = self.conn.cursor()

        execute_query_fetch_all(
            self.evadb, "CREATE TABLE IF NOT EXISTS test_table (val INTEGER);"
        )
        cursor.insert("test_table", "(val)", "(1)").df()
        return cursor

    def _assert_frames_equal(self, expected, result):
        """Assert ``expected.equals(result)`` and show both frames on failure."""
        self.assertTrue(
            expected.equals(result),
            msg=f"expected:\n{expected}\nactual:\n{result}",
        )

    def test_from_file(self):
        """A function registered from a source file adds 5 to the stored value."""
        cursor = self._cursor_with_test_table()

        cursor.create_function(
            "My_SimpleUDF",
            True,
            f"{EvaDB_ROOT_DIR}/evadb/functions/My_SimpleUDF.py",
        ).df()

        result = cursor.query("SELECT My_SimpleUDF(val) FROM test_table;").df()
        expected = pd.DataFrame({'output': [6]})

        self._assert_frames_equal(expected, result)

    def test_from_function(self):
        """A function registered from a callable adds 10 to the stored value."""
        cursor = self._cursor_with_test_table()

        cursor.create_simple_function(
            "Func_SimpleUDF",
            Func_SimpleUDF,
            True,
        ).df()

        result = cursor.query("SELECT Func_SimpleUDF(val) FROM test_table;").df()
        expected = pd.DataFrame({'output': [11]})

        self._assert_frames_equal(expected, result)

0 commit comments

Comments
 (0)