Typing and commenting update.py

Barbagus42 2023-11-07 09:13:48 +01:00
parent 410c5e44e8
commit 3968666904
2 changed files with 51 additions and 30 deletions


@ -6,12 +6,15 @@
"language": "en",
// words - list of words to be always considered correct
"words": [
"executemany",
"executescript",
"GTFS",
"headsign"
],
// flagWords - list of words to be always considered incorrect
// This is useful for offensive words and common spelling errors.
// For example "hte" should be "the"
"flagWords": [
],
"flagWords": [],
"overrides": [
{
"language": "fr-FR",


@ -7,6 +7,12 @@ import os.path
import math
import contextlib as ctx
from typing import Callable, Any, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from http.client import HTTPResponse
from zipfile import ZipFile
from sqlite3 import Connection as DBconnection
# GTFS reference:
# https://gtfs.org/schedule/reference/
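A note on the if TYPE_CHECKING block above: those imports exist only for the type checker, so using HTTPResponse, ZipFile or DBconnection unquoted in signatures that Python evaluates at runtime relies on postponed annotation evaluation (PEP 563); whether update.py already enables it is not visible in these hunks. A minimal sketch of the pattern, with a hypothetical status_of helper:

from __future__ import annotations  # defer annotation evaluation (PEP 563)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only while type checking; absent at runtime
    from http.client import HTTPResponse

def status_of(response: HTTPResponse) -> int:
    # the annotation is stored as a string and never evaluated at runtime,
    # so the guarded import is enough for the type checker
    return response.status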
@ -17,56 +23,68 @@ import contextlib as ctx
# SNCF/TER daily updated dataset:
GTFS_URL = "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip"
Cell = int | str | float | None
Mapper = Callable[[Any], Cell]
# This data set uses large strings as primary/foreign keys that do not carry
# information. In order to save some space in the database and time during
# requests, we substitute those large IDs with incrementing integers.
# Global register for primary key substitution
PKS: dict[str, dict[str, int]] = {}
def primary_key(table):
# Primary key substitution requester
def primary_key(table: str) -> Mapper:
assert table not in PKS
PKS[table] = {}
def map(v):
def map(v: str):
PKS[table][v] = len(PKS[table]) + 1
return len(PKS[table])
return map
def foreign_key(table):
def map(v):
# Foreign key lookup
def foreign_key(table: str) -> Mapper:
def map(v: str):
return PKS[table][v]
return map
def optional(f):
def map(v):
# A "can be null" mapper wrapper
def optional(f: Mapper) -> Mapper:
def map(v: str):
return None if v == "" else f(v)
return map
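The three mapper factories above are easiest to follow with a small example. This is a sketch against the definitions shown in this diff, not part of the commit; the table name "demo" and the ID strings are made up:

demo_pk = primary_key("demo")    # creates the "demo" register in PKS
demo_fk = foreign_key("demo")    # resolves values registered above

assert demo_pk("OCESN-87581234") == 1   # first string seen -> 1
assert demo_pk("OCESN-87590001") == 2   # second string seen -> 2
assert demo_fk("OCESN-87581234") == 1   # foreign keys map back to the same integer

maybe_int = optional(int)
assert maybe_int("") is None            # empty CSV cell -> NULL
assert maybe_int("3") == 3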
CSV_FIELDS = (
CSV_FIELDS: list[tuple[str, Optional[list[tuple[str, Optional[Mapper]]]]]] = [
(
"agency.txt",
(
[
("agency_id", primary_key("agency")),
("agency_name", str),
("agency_url", str),
("agency_timezone", str),
("agency_lang", str),
),
],
),
(
"calendar_dates.txt",
(
[
("service_id", int),
("date", str),
("exception_type", int),
),
],
),
(
"routes.txt",
(
[
("route_id", primary_key("routes")),
("agency_id", foreign_key("agency")),
("route_short_name", str),
@ -76,11 +94,11 @@ CSV_FIELDS = (
("route_url", None),
("route_color", optional(str)),
("route_text_color", optional(str)),
),
],
),
(
"trips.txt",
(
[
("route_id", foreign_key("routes")),
("service_id", int),
("trip_id", primary_key("trips")),
@ -88,11 +106,11 @@ CSV_FIELDS = (
("direction_id", optional(int)),
("block_id", int),
("shape_id", None),
),
],
),
(
"stops.txt",
(
[
("stop_id", primary_key("stops")),
("stop_name", str),
("stop_desc", None),
@ -102,11 +120,11 @@ CSV_FIELDS = (
("stop_url", None),
("location_type", int),
("parent_station", optional(foreign_key("stops"))),
),
],
),
(
"stop_times.txt",
(
[
("trip_id", foreign_key("trips")),
("arrival_time", str),
("departure_time", str),
@ -116,14 +134,14 @@ CSV_FIELDS = (
("pickup_type", int),
("drop_off_type", int),
("shape_dist_traveled", None),
),
],
),
("feed_info.txt", None),
("transfers.txt", None),
)
]
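Each CSV_FIELDS entry pairs a file name with either None (the file is ignored) or a list of (header, mapper) pairs, where a None mapper marks a column that is dropped. A hedged sketch of turning one CSV row into database parameters; row_to_params is a hypothetical helper, not the script's actual loading code, which these hunks do not show:

from typing import Any, Callable, Optional

Cell = int | str | float | None          # same aliases as in the diff above
Mapper = Callable[[Any], Cell]

def row_to_params(row: dict[str, str],
                  fields: list[tuple[str, Optional[Mapper]]]) -> tuple[Cell, ...]:
    # apply each kept column's mapper; columns whose mapper is None are skipped
    return tuple(mapper(row[header]) for header, mapper in fields if mapper is not None)

For "calendar_dates.txt", the row {"service_id": "1", "date": "20231107", "exception_type": "1"} would map to (1, "20231107", 1), ready to be batched into db.executemany with three placeholders.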
def _get_file_names(etag):
def _get_file_names(etag: str):
dir = os.path.dirname(__file__)
return (
os.path.join(dir, etag + ".zip"),
@ -133,7 +151,7 @@ def _get_file_names(etag):
)
def _fetch_dataset(response, dataset_file):
def _fetch_dataset(response: HTTPResponse, dataset_file: str):
print("Fetching dataset...")
content_length = int(response.getheader("Content-Length"))
with open(dataset_file, "wb") as zip_out:
@ -146,7 +164,7 @@ def _fetch_dataset(response, dataset_file):
print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")
def _check_dataset_files(zip_in):
def _check_dataset_files(zip_in: ZipFile):
csv_files = list(sorted(zip_in.namelist()))
expected = list(sorted(csv_file for csv_file, _ in CSV_FIELDS))
@ -154,20 +172,20 @@ def _check_dataset_files(zip_in):
assert all(a == b for a, b in zip(csv_files, expected, strict=True)), csv_files
def _check_csv_headers(csv_headers, fields):
def _check_csv_headers(csv_headers: list[str], fields: list[tuple[str, Any]]):
expected = list(header for header, _mapper in fields)
assert len(expected) == len(csv_headers), csv_headers
assert all(a == b for a, b in zip(csv_headers, expected)), csv_headers
def _create_schema(db, schema_file):
def _create_schema(db: DBconnection, schema_file: str):
with db:
with open(schema_file) as sql_in:
db.executescript(sql_in.read())
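For reference, executescript runs a whole SQL script in one call, and the surrounding with db: block commits it on success or rolls it back on error. A standalone sketch with a made-up table standing in for the real schema_file:

import sqlite3

db = sqlite3.connect(":memory:")
with db:  # commit on success, rollback on exception (the connection stays open)
    db.executescript(
        """
        -- made-up schema, not the project's schema file
        CREATE TABLE agency (
            agency_id INTEGER PRIMARY KEY,
            agency_name TEXT NOT NULL
        );
        """
    )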
def _load_data(zip_in, db):
def _load_data(zip_in: ZipFile, db: DBconnection):
for csv_file, fields in CSV_FIELDS:
if fields is None:
continue
@ -204,7 +222,7 @@ def _load_data(zip_in, db):
)
def _load_database(dataset_file, database_tempfile, schema_file):
def _load_database(dataset_file: str, database_tempfile: str, schema_file: str):
print("Loading database...")
with zipfile.ZipFile(dataset_file) as zip_in:
_check_dataset_files(zip_in)
@ -218,7 +236,7 @@ def _load_database(dataset_file, database_tempfile, schema_file):
def main():
print("Checking dataset")
response = request.urlopen(GTFS_URL)
response: HTTPResponse = request.urlopen(GTFS_URL)
if response.status != 200:
raise RuntimeError("Could not fetch the dataset")
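The rest of main() falls outside these hunks. The visible pieces suggest the overall flow: open GTFS_URL, verify the status, derive local file names from the response, then fetch and reload only when needed. A loose sketch of that wiring; the ETag handling and the tuple unpacking are guesses rather than the commit's code:

def main_sketch() -> None:  # hypothetical, not the script's main()
    response: "HTTPResponse" = request.urlopen(GTFS_URL)
    if response.status != 200:
        raise RuntimeError("Could not fetch the dataset")
    # guess: the dataset's ETag header keys the locally cached file names
    etag = response.getheader("ETag", "").strip('"')
    dataset_file, *_rest = _get_file_names(etag)
    if not os.path.exists(dataset_file):
        _fetch_dataset(response, dataset_file)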