1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+
1415#
1516# Author(s): Manish Sahani <[email protected] > 16-
1717import uuid
1818import sqlalchemy as S
1919
2020from sqlalchemy .orm import sessionmaker , scoped_session
2121from sqlalchemy .ext .declarative import declarative_base
2222from sqlalchemy .types import TypeDecorator , CHAR
23+ from sqlalchemy import event
2324
2425
2526BASE = declarative_base ()
2627
28+ """
29+ schema version, remember to update this
30+ whenever you make changes to the schema
31+ """
32+ schema_version = 2
33+
2734
2835def create_session (url ):
2936 """
@@ -46,6 +53,7 @@ def create_session(url):
4653def migrate (url ):
4754 """
4855 Create the tables in the database using the url
56+ Check if we need to upgrade the schema, and do that as well
4957
5058 Args:
5159 url (string): the URL used to connect the application to the
@@ -54,12 +62,74 @@ def migrate(url):
5462 ie: <engine>://<user>:<password>@<host>/<dbname>
5563 """
5664 engine = S .create_engine (url )
65+ update_schema (engine , schema_version )
5766 BASE .metadata .create_all (bind = engine )
67+
5868
5969 session = scoped_session (
6070 sessionmaker (bind = engine , autocommit = False , autoflush = False )
6171 )
6272 return session
73+
74+
75+ def update_schema (engine , schema_version ):
76+ """
77+ Primitive database schema upgrade facility, designed to work
78+ with production Elekto databases
79+
80+ Currently only works with PostgreSQL due to requiring transaction
81+ support for DDL statements. MySQL, SQLite backends will error.
82+
83+ Start by figuring out our schema version, and then upgrade
84+ stepwise until we match
85+ """
86+ db_version = 1
87+ db_schema = S .inspect (engine )
88+
89+ if db_schema .has_table ("election" ):
90+ if db_schema .has_table ("schema_version" ):
91+ db_version = engine .execute ('select version from schema_version' ).scalar ()
92+ if db_version is None :
93+ """ intialize the table, if necessary """
94+ engine .execute ('insert into schema_version ( version ) values ( 2 )' )
95+ else :
96+ """ new, empty db """
97+ return schema_version
98+
99+ while db_version < schema_version :
100+ if engine .dialect .name != "postgresql" :
101+ raise RuntimeError ('Upgrading the schema is required, but the database is not PostgreSQL. You will need to upgrade manually.' )
102+
103+ if db_version < 2 :
104+ db_version = update_schema_2 (engine )
105+ continue
106+
107+ return db_version ;
108+
109+
110+ def update_schema_2 (engine ):
111+ """
112+ update from schema version 1 to schema version 2
113+ as a set of raw SQL statements
114+ currently only works for PostgreSQL
115+ written this way because SQLalchemy can't handle the
116+ steps involved without data loss
117+ """
118+ session = scoped_session (sessionmaker (bind = engine ))
119+
120+ session .execute ('CREATE TABLE schema_version ( version INT PRIMARY KEY);' )
121+ session .execute ('INSERT INTO schema_version VALUES ( 2 );' )
122+ session .execute ('ALTER TABLE voter ADD COLUMN salt BYTEA, ADD COLUMN ballot_id BYTEA;' )
123+ session .execute ('CREATE INDEX voter_election_id ON voter(election_id);' )
124+ session .execute ('ALTER TABLE ballot DROP COLUMN created_at, DROP COLUMN updated_at;' )
125+ session .execute ('ALTER TABLE ballot DROP CONSTRAINT ballot_pkey;' )
126+ session .execute ("ALTER TABLE ballot ALTER COLUMN id TYPE CHAR(32) USING to_char(id , 'FM00000000000000000000000000000000');" )
127+ session .execute ('ALTER TABLE ballot ALTER COLUMN id DROP DEFAULT;' )
128+ session .execute ('ALTER TABLE ballot ADD CONSTRAINT ballot_pkey PRIMARY KEY ( id );' )
129+ session .execute ('CREATE INDEX ballot_election_id ON ballot(election_id);' )
130+ session .commit ()
131+
132+ return 2
63133
64134
65135class UUID (TypeDecorator ):
@@ -94,6 +164,19 @@ def process_result_value(self, value, dialect):
94164 return value
95165
96166
167+ class Version (BASE ):
168+ """
169+ Stores Elekto schema version in the database for ad-hoc upgrades
170+ """
171+ __tablename__ = "schema_version"
172+
173+ # Attributes
174+ version = S .Column (S .Integer , default = schema_version , primary_key = True )
175+
176+ @event .listens_for (Version .__table__ , 'after_create' )
177+ def create_version (target , connection , ** kwargs ):
178+ connection .execute (f"INSERT INTO schema_version ( version ) VALUES ( { schema_version } )" )
179+
97180class User (BASE ):
98181 """
99182 User Schema - registered from the oauth external application - github
@@ -185,11 +268,11 @@ class Voter(BASE):
185268
186269 id = S .Column (S .Integer , primary_key = True )
187270 user_id = S .Column (S .Integer , S .ForeignKey ("user.id" , ondelete = "CASCADE" ))
188- election_id = S .Column (S .Integer , S .ForeignKey ("election.id" , ondelete = "CASCADE" ))
271+ election_id = S .Column (S .Integer , S .ForeignKey ("election.id" , ondelete = "CASCADE" ), index = True )
189272 created_at = S .Column (S .DateTime , default = S .func .now ())
190273 updated_at = S .Column (S .DateTime , default = S .func .now ())
191- salt = S .Column (S .LargeBinary , nullable = False )
192- ballot_id = S .Column (S .LargeBinary , nullable = False ) # encrypted
274+ salt = S .Column (S .LargeBinary )
275+ ballot_id = S .Column (S .LargeBinary ) # encrypted
193276
194277 # Relationships
195278
@@ -227,7 +310,7 @@ class Ballot(BASE):
227310
228311 # Attributes
229312 id = S .Column (UUID (), primary_key = True , default = uuid .uuid4 )
230- election_id = S .Column (S .Integer , S .ForeignKey ("election.id" , ondelete = "CASCADE" ))
313+ election_id = S .Column (S .Integer , S .ForeignKey ("election.id" , ondelete = "CASCADE" ), index = True )
231314 rank = S .Column (S .Integer , default = 100000000 )
232315 candidate = S .Column (S .String (255 ), nullable = False )
233316 voter = S .Column (S .String (255 ), nullable = False ) # uuid
0 commit comments