diff --git a/.cspell.json b/.cspell.json
index ef33973..8c48770 100644
--- a/.cspell.json
+++ b/.cspell.json
@@ -6,12 +6,15 @@
     "language": "en",
     // words - list of words to be always considered correct
     "words": [
+        "executemany",
+        "executescript",
+        "GTFS",
+        "headsign"
     ],
     // flagWords - list of words to be always considered incorrect
     // This is useful for offensive words and common spelling errors.
     // For example "hte" should be "the"
-    "flagWords": [
-    ],
+    "flagWords": [],
     "overrides": [
         {
             "language": "fr-FR",
diff --git a/data/update.py b/data/update.py
index b36f9f0..b7a3ea6 100644
--- a/data/update.py
+++ b/data/update.py
@@ -7,6 +7,12 @@ import os.path
 import math
 import contextlib as ctx
 
+from typing import Callable, Any, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from http.client import HTTPResponse
+    from zipfile import ZipFile
+    from sqlite3 import Connection as DBconnection
 
 # GTFS reference:
 # https://gtfs.org/schedule/reference/
@@ -17,56 +23,68 @@ import contextlib as ctx
 # SNCF/TER daily updated dataset:
 GTFS_URL = "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip"
 
+
+Cell = int | str | float | None
+Mapper = Callable[[Any], Cell]
+
+# This data set uses large strings as primary/foreign keys that do not carry
+# information. In order to save some space in the database and time during
+# requests, we substitute those large IDs with incrementing integers.
+
+# Global register for primary key substitution
 PKS: dict[str, dict[str, int]] = {}
 
 
-def primary_key(table):
+# Primary key substitution requester
+def primary_key(table: str) -> Mapper:
     assert table not in PKS
     PKS[table] = {}
 
-    def map(v):
+    def map(v: str):
         PKS[table][v] = len(PKS[table]) + 1
         return len(PKS[table])
 
     return map
 
 
-def foreign_key(table):
-    def map(v):
+# Foreign key lookup
+def foreign_key(table: str) -> Mapper:
+    def map(v: str):
         return PKS[table][v]
 
     return map
 
 
-def optional(f):
-    def map(v):
+# A "can be null" mapper wrapper
+def optional(f: Mapper) -> Mapper:
+    def map(v: str):
         return None if v == "" else f(v)
 
     return map
 
 
-CSV_FIELDS = (
+CSV_FIELDS: list[tuple[str, Optional[list[tuple[str, Optional[Mapper]]]]]] = [
     (
         "agency.txt",
-        (
+        [
             ("agency_id", primary_key("agency")),
             ("agency_name", str),
             ("agency_url", str),
             ("agency_timezone", str),
             ("agency_lang", str),
-        ),
+        ],
     ),
     (
         "calendar_dates.txt",
-        (
+        [
             ("service_id", int),
             ("date", str),
             ("exception_type", int),
-        ),
+        ],
     ),
     (
         "routes.txt",
-        (
+        [
             ("route_id", primary_key("routes")),
             ("agency_id", foreign_key("agency")),
             ("route_short_name", str),
@@ -76,11 +94,11 @@ CSV_FIELDS = (
             ("route_url", None),
             ("route_color", optional(str)),
             ("route_text_color", optional(str)),
-        ),
+        ],
     ),
     (
         "trips.txt",
-        (
+        [
             ("route_id", foreign_key("routes")),
             ("service_id", int),
             ("trip_id", primary_key("trips")),
@@ -88,11 +106,11 @@ CSV_FIELDS = (
             ("direction_id", optional(int)),
             ("block_id", int),
             ("shape_id", None),
-        ),
+        ],
     ),
     (
         "stops.txt",
-        (
+        [
             ("stop_id", primary_key("stops")),
             ("stop_name", str),
             ("stop_desc", None),
@@ -102,11 +120,11 @@ CSV_FIELDS = (
             ("stop_url", None),
             ("location_type", int),
             ("parent_station", optional(foreign_key("stops"))),
-        ),
+        ],
     ),
     (
         "stop_times.txt",
-        (
+        [
             ("trip_id", foreign_key("trips")),
             ("arrival_time", str),
             ("departure_time", str),
@@ -116,14 +134,14 @@ CSV_FIELDS = (
             ("pickup_type", int),
             ("drop_off_type", int),
             ("shape_dist_traveled", None),
-        ),
+        ],
     ),
     ("feed_info.txt", None),
     ("transfers.txt", None),
-)
+]
 
 
-def _get_file_names(etag):
+def _get_file_names(etag: str):
     dir = os.path.dirname(__file__)
     return (
         os.path.join(dir, etag + ".zip"),
@@ -133,7 +151,7 @@ def _get_file_names(etag):
     )
 
 
-def _fetch_dataset(response, dataset_file):
+def _fetch_dataset(response: HTTPResponse, dataset_file: str):
     print("Fetching dataset...")
     content_length = int(response.getheader("Content-Length"))
     with open(dataset_file, "wb") as zip_out:
@@ -146,7 +164,7 @@ def _fetch_dataset(response, dataset_file):
             print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")
 
 
-def _check_dataset_files(zip_in):
+def _check_dataset_files(zip_in: ZipFile):
     csv_files = list(sorted(zip_in.namelist()))
 
     expected = list(sorted(csv_file for csv_file, _ in CSV_FIELDS))
@@ -154,20 +172,20 @@ def _check_dataset_files(zip_in):
     assert all(a == b for a, b in zip(csv_files, expected, strict=True)), csv_files
 
 
-def _check_csv_headers(csv_headers, fields):
+def _check_csv_headers(csv_headers: list[str], fields: list[tuple[str, Any]]):
     expected = list(header for header, _mapper in fields)
 
     assert len(expected) == len(csv_headers), csv_headers
     assert all(a == b for a, b in zip(csv_headers, expected)), csv_headers
 
 
-def _create_schema(db, schema_file):
+def _create_schema(db: DBconnection, schema_file: str):
     with db:
         with open(schema_file) as sql_in:
             db.executescript(sql_in.read())
 
 
-def _load_data(zip_in, db):
+def _load_data(zip_in: ZipFile, db: DBconnection):
     for csv_file, fields in CSV_FIELDS:
         if fields is None:
             continue
@@ -204,7 +222,7 @@ def _load_data(zip_in, db):
             )
 
 
-def _load_database(dataset_file, database_tempfile, schema_file):
+def _load_database(dataset_file: str, database_tempfile: str, schema_file: str):
     print("Loading database...")
     with zipfile.ZipFile(dataset_file) as zip_in:
         _check_dataset_files(zip_in)
@@ -218,7 +236,7 @@ def _load_database(dataset_file, database_tempfile, schema_file):
 
 def main():
     print("Checking dataset")
-    response = request.urlopen(GTFS_URL)
+    response: HTTPResponse = request.urlopen(GTFS_URL)
     if response.status != 200:
         raise RuntimeError("Could not fetch the dataset")
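
A minimal standalone sketch of the ID-substitution scheme the patch documents above, for illustration only: it is not part of the patch, the function bodies simply mirror the helpers in data/update.py, and the sample IDs are hypothetical.

# Illustration only -- mirrors the primary_key/foreign_key helpers added in
# data/update.py; the sample IDs below are hypothetical.
PKS: dict[str, dict[str, int]] = {}


def primary_key(table: str):
    assert table not in PKS
    PKS[table] = {}

    def map(v: str):
        # First time a key is seen: assign the next incrementing integer
        PKS[table][v] = len(PKS[table]) + 1
        return len(PKS[table])

    return map


def foreign_key(table: str):
    def map(v: str):
        # Later references to the key reuse the integer assigned above
        return PKS[table][v]

    return map


agency_pk = primary_key("agency")
agency_fk = foreign_key("agency")

print(agency_pk("AGENCY:LONG-OPAQUE-ID-1"))  # 1: long string key becomes a small int
print(agency_pk("AGENCY:LONG-OPAQUE-ID-2"))  # 2
print(agency_fk("AGENCY:LONG-OPAQUE-ID-1"))  # 1: referencing rows look the int back up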