Skip to content

Commit d810d7d

Browse files
authored
feat: add pg8000 support (#40)
* feat: add support for connecting with pg8000 * add docstring * linting * update _connect_with_pg8000 to work with updated version of pg8000 * update env var names * subclass ssl.SSLContext to get python3.6 tests passing * blacken/lint * fix typo in docstring
1 parent 4edcf11 commit d810d7d

4 files changed

Lines changed: 129 additions & 2 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
env
33
venv
44
*.pyc
5+
.python-version

google/cloud/sql/connector/InstanceConnectionManager.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@
4242
_sql_api_version: str = "v1beta4"
4343

4444

45+
class ConnectionSSLContext(ssl.SSLContext):
46+
"""Subclass of ssl.SSLContext with added request_ssl attribute. This is
47+
required for compatibility with pg8000 driver.
48+
"""
49+
50+
def __init__(self, *args, **kwargs):
51+
self.request_ssl = False
52+
super(ConnectionSSLContext, self).__init__(*args, **kwargs)
53+
54+
4555
class InstanceMetadata:
4656
ip_address: str
4757
_ca_fileobject: NamedTemporaryFile
@@ -73,7 +83,7 @@ def __init__(
7383
self._cert_fileobject.seek(0)
7484
self._key_fileobject.seek(0)
7585

76-
self.context = ssl.SSLContext()
86+
self.context = ConnectionSSLContext()
7787
self.context.load_cert_chain(
7888
self._cert_fileobject.name, keyfile=self._key_fileobject.name
7989
)
@@ -455,8 +465,13 @@ def connect(self, driver: str, **kwargs) -> Any:
455465
with self._lock:
456466
instance_data: InstanceMetadata = self._current.result()
457467

468+
connect_func = {
469+
"pymysql": self._connect_with_pymysql,
470+
"pg8000": self._connect_with_pg8000,
471+
}
472+
458473
try:
459-
connector = {"pymysql": self._connect_with_pymysql}[driver]
474+
connector = connect_func[driver]
460475
except KeyError:
461476
raise KeyError("Driver {} is not supported.".format(driver))
462477

@@ -499,3 +514,38 @@ def _connect_with_pymysql(self, ip_address: str, ctx: ssl.SSLContext, **kwargs):
499514
conn = pymysql.Connection(host=ip_address, defer_connect=True, **kwargs)
500515
conn.connect(sock)
501516
return conn
517+
518+
def _connect_with_pg8000(self, ip_address: str, ctx: ssl.SSLContext, **kwargs):
519+
"""Helper function to create a pg8000 DB-API connection object.
520+
521+
:type ip_address: str
522+
:param ip_address: A string containing an IP address for the Cloud SQL
523+
instance.
524+
525+
:type ctx: ssl.SSLContext
526+
:param ctx: An SSLContext object created from the Cloud SQL server CA
527+
cert and ephemeral cert.
528+
529+
530+
:rtype: pg8000.dbapi.Connection
531+
:returns: A pg8000 Connection object for the Cloud SQL instance.
532+
"""
533+
try:
534+
import pg8000
535+
except ImportError:
536+
raise ImportError(
537+
'Unable to import module "pg8000." Please install and try again.'
538+
)
539+
user = kwargs.pop("user")
540+
db = kwargs.pop("db")
541+
passwd = kwargs.pop("password")
542+
ctx.request_ssl = False
543+
return pg8000.dbapi.connect(
544+
user,
545+
database=db,
546+
password=passwd,
547+
host=ip_address,
548+
port=3307,
549+
ssl_context=ctx,
550+
**kwargs
551+
)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
aiohttp==3.7.4
22
cryptography==3.4.6
33
PyMySQL==1.0.2
4+
#pg8000==1.17.0
5+
git+https://github.com/tlocke/pg8000.git@37dcbe3#egg=pg8000
46
pyopenssl==20.0.1
57
pytest==6.2.2
68
Requests==2.25.1
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
""""
2+
Copyright 2021 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import os
17+
import uuid
18+
19+
import pytest
20+
import sqlalchemy
21+
from google.cloud.sql.connector import connector
22+
23+
table_name = f"books_{uuid.uuid4().hex}"
24+
25+
26+
def init_connection_engine():
27+
def getconn():
28+
conn = connector.connect(
29+
os.environ["POSTGRES_CONNECTION_NAME"],
30+
"pg8000",
31+
user=os.environ["POSTGRES_USER"],
32+
password=os.environ["POSTGRES_PASS"],
33+
db=os.environ["POSTGRES_DB"],
34+
)
35+
return conn
36+
37+
engine = sqlalchemy.create_engine(
38+
"postgresql+pg8000://",
39+
creator=getconn,
40+
)
41+
engine.dialect.description_encoding = None
42+
return engine
43+
44+
45+
@pytest.fixture(name="pool")
46+
def setup():
47+
pool = init_connection_engine()
48+
49+
with pool.connect() as conn:
50+
conn.execute(
51+
f"CREATE TABLE IF NOT EXISTS {table_name}"
52+
" ( id CHAR(20) NOT NULL, title TEXT NOT NULL );"
53+
)
54+
55+
yield pool
56+
57+
with pool.connect() as conn:
58+
conn.execute(f"DROP TABLE IF EXISTS {table_name}")
59+
60+
61+
def test_pooled_connection_with_pymysql(pool):
62+
insert_stmt = sqlalchemy.text(
63+
f"INSERT INTO {table_name} (id, title) VALUES (:id, :title)",
64+
)
65+
with pool.connect() as conn:
66+
conn.execute(insert_stmt, id="book1", title="Book One")
67+
conn.execute(insert_stmt, id="book2", title="Book Two")
68+
69+
select_stmt = sqlalchemy.text(f"SELECT title FROM {table_name} ORDER BY ID;")
70+
with pool.connect() as conn:
71+
rows = conn.execute(select_stmt).fetchall()
72+
titles = [row[0] for row in rows]
73+
74+
assert titles == ["Book One", "Book Two"]

0 commit comments

Comments
 (0)