From 24b75b67eee393544789a71fec87a84a31c30c18 Mon Sep 17 00:00:00 2001 From: Barbagus42 Date: Wed, 8 Nov 2023 20:59:35 +0100 Subject: [PATCH] Use async interface to database --- .cspell.json | 24 ------------------------ pyproject.toml | 1 + ter/config.py | 2 ++ ter/helpers.py | 36 ++++++++++++++++++++++++++---------- ter/main.py | 21 +++++++++++++++++++-- ter/routes.py | 8 ++++---- 6 files changed, 52 insertions(+), 40 deletions(-) delete mode 100644 .cspell.json diff --git a/.cspell.json b/.cspell.json deleted file mode 100644 index 8c48770..0000000 --- a/.cspell.json +++ /dev/null @@ -1,24 +0,0 @@ -// cSpell Settings -{ - // Version of the setting file. Always 0.2 - "version": "0.2", - // language - current active spelling language - "language": "en", - // words - list of words to be always considered correct - "words": [ - "executemany", - "executescript", - "GTFS", - "headsign" - ], - // flagWords - list of words to be always considered incorrect - // This is useful for offensive words and common spelling errors. - // For example "hte" should be "the" - "flagWords": [], - "overrides": [ - { - "language": "fr-FR", - "filename": "**.md" - } - ] -} diff --git a/pyproject.toml b/pyproject.toml index 6f34f0e..0c76830 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "pydantic-settings", "jinja2", "jinja2-fragments", + "aiosqlite", ] requires-python = ">=3.11" diff --git a/ter/config.py b/ter/config.py index 5425da8..4ff515a 100644 --- a/ter/config.py +++ b/ter/config.py @@ -11,6 +11,8 @@ class Settings(BaseSettings): TEMPLATE_DIR: str = path.join(APP_DIR, "templates") DATA_DIR: str = path.join(path.dirname(APP_DIR), "data") + SQLITE_URI: str = f"file:{path.join(DATA_DIR, 'db.sqlite')}?mode=ro" + FASTAPI_PROPERTIES: dict = { "title": "TER", "description": "", diff --git a/ter/helpers.py b/ter/helpers.py index d193b53..44593c0 100644 --- a/ter/helpers.py +++ b/ter/helpers.py @@ -1,16 +1,32 @@ -import sqlite3 -import os.path as path -import contextlib as ctx +from typing import Any +from collections.abc import Sequence, Iterable + +import aiosqlite from ter.config import Settings settings = Settings() -def connect_db(): - return ctx.closing( - sqlite3.connect( - f"file:{path.join(settings.DATA_DIR, 'db.sqlite')}?mode=ro", - uri=True, - ) - ) +class Database: + Params = Sequence[Any] | dict[str, Any] + + def __init__(self, uri: str) -> None: + self._uri: str = uri + self._connection: aiosqlite.Connection | None = None + + async def connect(self): + self._connection = await aiosqlite.connect(self._uri, uri=True) + + async def disconnect(self): + await self._connection.close() + self._connection = None + + async def execute(self, sql: str, params: Params = ()): + return await self._connection.execute(sql, params) + + async def executemany(self, sql: str, params: Iterable[Params]): + return await self._connection.executemany(sql, params) + + +database = Database(settings.SQLITE_URI) diff --git a/ter/main.py b/ter/main.py index 9135924..dcf9805 100644 --- a/ter/main.py +++ b/ter/main.py @@ -1,15 +1,32 @@ +import contextlib as ctx + from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from ter.config import Settings +from ter.helpers import database from ter.routes import router settings = Settings() +@ctx.asynccontextmanager +async def lifespan(app: FastAPI): + await database.connect() + yield + await database.disconnect() + + def get_app() -> FastAPI: - app = FastAPI(**settings.FASTAPI_PROPERTIES) - app.mount("/static", StaticFiles(directory=settings.STATIC_DIR), name="static") + app = FastAPI( + **settings.FASTAPI_PROPERTIES, + lifespan=lifespan, + ) + app.mount( + "/static", + StaticFiles(directory=settings.STATIC_DIR), + name="static", + ) app.include_router(router) return app diff --git a/ter/routes.py b/ter/routes.py index b38bedb..1f2b7e2 100644 --- a/ter/routes.py +++ b/ter/routes.py @@ -3,7 +3,7 @@ from jinja2_fragments.fastapi import Jinja2Blocks from ter.config import Settings -from ter.helpers import connect_db +from ter.helpers import database settings = Settings() router = APIRouter() @@ -11,11 +11,11 @@ templates = Jinja2Blocks(settings.TEMPLATE_DIR) @router.get("/") -def index(request: Request): +async def index(request: Request): """Home page.""" - with connect_db() as db: - agencies = db.execute("SELECT * FROM agency").fetchall() + cursor = await database.execute("SELECT * FROM agency") + agencies = await cursor.fetchall() context = {"request": request, "agencies": agencies}