Skip to content

Commit b992623

Browse files
rad-patpan3793
andcommitted
[KYUUBI #7190] Fix Presto SQLAlchemy dialect did not implement get_view_names
Presto SQLAlchemy dialect did not implement the `get_view_names` method and resulted in an exception when trying to inspect the schema. This was discovered in Superset repo whilst trying to update the pandas package which now makes a call to `get_view_names`. Very basic Python tests have been added here, but all the SQLAlchemy dialects in this repo would benefit from running the full SQLAlchemy dialect test suite instead of these bespoke tests. Closes #7190 from rad-pat/fix-presto-dialect. Closes #7190 c2d06f7 [Cheng Pan] Update python/pyhive/sqlalchemy_presto.py 2739697 [Cheng Pan] Update python/pyhive/sqlalchemy_presto.py 1c7b628 [Cheng Pan] Update python/pyhive/sqlalchemy_presto.py 2e7040a [Cheng Pan] Update python/pyhive/sqlalchemy_presto.py 89d3f55 [Cheng Pan] Update python/pyhive/__init__.py b8deadc [Pat Buxton] Bump python version to 0.7.1 ab829ee [Pat Buxton] Fix - Presto SQLAlchemy dialect did not implement get_view_names Lead-authored-by: Pat Buxton <[email protected]> Co-authored-by: Cheng Pan <[email protected]> Signed-off-by: Cheng Pan <[email protected]>
1 parent 0c56e65 commit b992623

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

python/pyhive/sqlalchemy_presto.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sqlalchemy.dialects import mysql
2424
mysql_tinyinteger = mysql.base.MSTinyInteger
2525
from sqlalchemy.engine import default
26-
from sqlalchemy.sql import compiler
26+
from sqlalchemy.sql import compiler, bindparam
2727
from sqlalchemy.sql.compiler import SQLCompiler
2828

2929
from pyhive import presto
@@ -204,12 +204,45 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
204204
else:
205205
return []
206206

207+
def _get_default_schema_name(self, connection):
208+
#'SELECT CURRENT_SCHEMA()'
209+
return super()._get_default_schema_name(connection)
210+
207211
def get_table_names(self, connection, schema=None, **kw):
208212
query = 'SHOW TABLES'
213+
# N.B. This is incorrect, if no schema is provided, the current/default schema should be used
214+
# with a call to an overridden self._get_default_schema_name(connection), but I could not
215+
# see how to implement that as there is no CURRENT_SCHEMA function
216+
# default_schema = self._get_default_schema_name(connection)
217+
209218
if schema:
210219
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
211220
return [row.Table for row in connection.execute(text(query))]
212221

222+
def get_view_names(self, connection, schema=None, **kw):
223+
if schema:
224+
view_name_query = """
225+
SELECT table_name
226+
FROM information_schema.views
227+
WHERE table_schema = :schema
228+
"""
229+
query = text(view_name_query).bindparams(
230+
bindparam("schema", type_=types.Unicode)
231+
)
232+
else:
233+
# N.B. This is incorrect, if no schema is provided, the current/default schema should
234+
# be used with a call to self._get_default_schema_name(connection), but I could not
235+
# see how to implement that
236+
# default_schema = self._get_default_schema_name(connection)
237+
view_name_query = """
238+
SELECT table_name
239+
FROM information_schema.views
240+
"""
241+
query = text(view_name_query)
242+
243+
result = connection.execute(query, dict(schema=schema))
244+
return [row[0] for row in result]
245+
213246
def do_rollback(self, dbapi_connection):
214247
# No transactions for Presto
215248
pass

python/pyhive/tests/test_sqlalchemy_presto.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,31 @@ def test_hash_table(self, engine, connection):
102102
self.assertFalse(insp.has_table("THIS_TABLE_DOSE_not_exist"))
103103
else:
104104
self.assertFalse(Table('THIS_TABLE_DOSE_NOT_EXIST', MetaData(bind=engine)).exists())
105-
self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', MetaData(bind=engine)).exists())
105+
self.assertFalse(Table('THIS_TABLE_DOSE_not_exits', MetaData(bind=engine)).exists())
106+
107+
@with_engine_connection
108+
def test_reflect_table_names(self, engine, connection):
109+
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
110+
if sqlalchemy_version >= 1.4:
111+
insp = sqlalchemy.inspect(engine)
112+
table_names = insp.get_table_names()
113+
self.assertIn("one_row", table_names)
114+
self.assertIn("one_row_complex", table_names)
115+
self.assertIn("many_rows", table_names)
116+
self.assertNotIn("THIS_TABLE_DOES_not_exist", table_names)
117+
else:
118+
self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
119+
self.assertTrue(Table('one_row_complex', MetaData(bind=engine)).exists())
120+
self.assertTrue(Table('many_rows', MetaData(bind=engine)).exists())
121+
self.assertFalse(Table('THIS_TABLE_DOES_not_exist', MetaData(bind=engine)).exists())
122+
123+
@with_engine_connection
124+
def test_reflect_view_names(self, engine, connection):
125+
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
126+
if sqlalchemy_version >= 1.4:
127+
insp = sqlalchemy.inspect(engine)
128+
view_names = insp.get_view_names()
129+
self.assertNotIn("one_row", view_names)
130+
self.assertNotIn("one_row_complex", view_names)
131+
self.assertNotIn("many_rows", view_names)
132+
self.assertNotIn("THIS_TABLE_DOES_not_exist", view_names)

0 commit comments

Comments
 (0)