diff --git a/desktop/libs/librdbms/src/librdbms/server/dbms.py b/desktop/libs/librdbms/src/librdbms/server/dbms.py index 5a6831f06c..1256e784bf 100644 --- a/desktop/libs/librdbms/src/librdbms/server/dbms.py +++ b/desktop/libs/librdbms/src/librdbms/server/dbms.py @@ -28,6 +28,7 @@ MYSQL = 'mysql' POSTGRESQL = 'postgresql' SQLITE = 'sqlite' ORACLE = 'oracle' +ODBC = 'odbc' def get(user, query_server=None): @@ -50,6 +51,11 @@ def get(user, query_server=None): from librdbms.server.oracle_lib import OracleClient return Rdbms(OracleClient(query_server, user), ORACLE) + elif query_server['server_name'] == 'django-pyodbc': + from librdbms.server.pyodbc_lib import ODBCClient + + return Rdbms(ODBCClient(query_server, user), ODBC) + def get_query_server_config(server=None): diff --git a/desktop/libs/librdbms/src/librdbms/server/pyodbc_lib.py b/desktop/libs/librdbms/src/librdbms/server/pyodbc_lib.py new file mode 100644 index 0000000000..83832f4025 --- /dev/null +++ b/desktop/libs/librdbms/src/librdbms/server/pyodbc_lib.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# Licensed to Cloudera, Inc. under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. Cloudera, Inc. licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +try: + import pyodbc as Database +except ImportError, e: + from django.core.exceptions import ImproperlyConfigured + raise ImproperlyConfigured("Error loading pyodbc module: %s" % e) + +from librdbms.server.rdbms_base_lib import BaseRDBMSDataTable, BaseRDBMSResult, BaseRDMSClient + + +LOG = logging.getLogger(__name__) + + +class DataTable(BaseRDBMSDataTable): + """Fixing cursor handling for inserts/updates""" + @property + def has_more(self): + if not self.next: + if self.cursor.rowcount == -1: + rows = self.cursor.fetchmany(self.fetchSize) + self.next = list(list(row) for row in rows) + return bool(self.next) + + +class Result(BaseRDBMSResult): pass + + +class ODBCClient(BaseRDMSClient): + """Same API as Beeswax""" + + data_table_cls = DataTable + result_cls = Result + + def __init__(self, *args, **kwargs): + super(ODBCClient, self).__init__(*args, **kwargs) + self.connection = Database.connect(self._conn_params) + self.connection.autocommit = True + + @property + def _conn_params(self): + params = { + 'user': self.query_server['username'], + 'password': self.query_server['password'], + 'host': self.query_server['server_host'], + 'port': self.query_server['server_port'] == 0 and 5432 or self.query_server['server_port'], + 'database': self.query_server['name'] + } + + if self.query_server['options']: + params.update(self.query_server['options']) + + return 'DSN=%s;UID=%s;PWD=%s' % (self.query_server['name'],self.query_server['username'],self.query_server['password']) + + def use(self, database): + # No op since postgresql requires a new connection per database + pass + + + def execute_statement(self, statement): + cursor = self.connection.cursor() + cursor.execute(statement) + if cursor.description: + columns = [column[0] for column in cursor.description] + else: + columns = [] + return self.data_table_cls(cursor, columns) + + + def get_databases(self): + # List all the schemas in the database + try: + cursor = self.connection.cursor() + cursor.tables(schema='%', table='', catalog='') # Empty Strings per the ODBC Spec. 1980 we want our computing back. + return [row[1].encode('utf8') for row in cursor.fetchall()] + except Exception: + LOG.exception('Failed to fetch databases from pyodbc cursor.tables') + return [self._conn_params['database']] + + # To Do Table_Names + def get_tables(self, database, table_names=[]): + cursor = self.connection.cursor() + cursor.tables(schema=database) + return [row[2].encode('utf8') for row in cursor.fetchall()] + + def get_columns(self, database, table, names_only=True): + cursor = self.connection.cursor() + cursor.columns(table=table, schema=database) + if names_only: + columns = [row[3].encode('utf8') for row in cursor.fetchall()] + else: + columns = [dict(name=row[3].encode('utf8'), type=row[5], comment='') for row in cursor.fetchall()] + return columns + + def get_sample_data(self, database, table, column=None, limit=100): + column = '"%s"' % column if column else '*' + statement = 'SELECT %s FROM "%s"."%s" {LIMIT %d}' % (column, database, table, limit) + return self.execute_statement(statement)