Typing and commenting update.py

parent 410c5e44e8
commit 3968666904
@@ -6,12 +6,15 @@
     "language": "en",
+    // words - list of words to be always considered correct
     "words": [
+        "executemany",
+        "executescript",
         "GTFS",
+        "headsign"
     ],
     // flagWords - list of words to be always considered incorrect
     // This is useful for offensive words and common spelling errors.
     // For example "hte" should be "the"
-    "flagWords": [
-    ],
+    "flagWords": [],
     "overrides": [
         {
             "language": "fr-FR",
@@ -7,6 +7,12 @@ import os.path
 import math
 import contextlib as ctx
 
+from typing import Callable, Any, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from http.client import HTTPResponse
+    from zipfile import ZipFile
+    from sqlite3 import Connection as DBconnection
 
 # GTFS reference:
 # https://gtfs.org/schedule/reference/
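A note on this hunk: names imported under `TYPE_CHECKING` do not exist at runtime, yet the functions below annotate parameters with them unquoted. Parameter annotations are normally evaluated at definition time, so this pattern relies on `from __future__ import annotations` (PEP 563) in the unshown top of update.py, or on quoted annotations. A minimal standalone illustration of the pattern; `status_of` is a hypothetical example function, not code from this commit:

from __future__ import annotations  # PEP 563: annotations stay unevaluated
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from http.client import HTTPResponse  # imported for type checkers only

def status_of(response: HTTPResponse) -> int:
    # Safe at runtime: the annotation above is never evaluated.
    return response.status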
@@ -17,56 +23,68 @@ import contextlib as ctx
 # SNCF/TER daily updated dataset:
 GTFS_URL = "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip"
 
+
+Cell = int | str | float | None
+Mapper = Callable[[Any], Cell]
+
+# This data set uses large strings as primary/foreign keys that do not carry
+# information. In order to save some space in the database and time during
+# requests, we substitute those large IDs with incrementing integers.
+
+# Global register for primary key substitution
 PKS: dict[str, dict[str, int]] = {}
 
 
-def primary_key(table):
+# Primary key substitution requester
+def primary_key(table: str) -> Mapper:
     assert table not in PKS
     PKS[table] = {}
 
-    def map(v):
+    def map(v: str):
         PKS[table][v] = len(PKS[table]) + 1
         return len(PKS[table])
 
     return map
 
 
-def foreign_key(table):
-    def map(v):
+# Foreign key lookup
+def foreign_key(table: str) -> Mapper:
+    def map(v: str):
         return PKS[table][v]
 
     return map
 
 
-def optional(f):
-    def map(v):
+# A "can be null" mapper wrapper
+def optional(f: Mapper) -> Mapper:
+    def map(v: str):
         return None if v == "" else f(v)
 
     return map
 
 
-CSV_FIELDS = (
+CSV_FIELDS: list[tuple[str, Optional[list[tuple[str, Optional[Mapper]]]]]] = [
     (
         "agency.txt",
-        (
+        [
             ("agency_id", primary_key("agency")),
             ("agency_name", str),
             ("agency_url", str),
             ("agency_timezone", str),
             ("agency_lang", str),
-        ),
+        ],
     ),
     (
         "calendar_dates.txt",
-        (
+        [
             ("service_id", int),
             ("date", str),
             ("exception_type", int),
-        ),
+        ],
     ),
     (
         "routes.txt",
-        (
+        [
             ("route_id", primary_key("routes")),
             ("agency_id", foreign_key("agency")),
             ("route_short_name", str),
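To make the key-substitution mechanism concrete, here is a short usage sketch. It assumes the `primary_key`, `foreign_key`, and `optional` definitions above are in scope; the table name and long IDs are made up for illustration:

pk = primary_key("stations")  # register a hypothetical table
assert pk("very-long-opaque-id-A") == 1
assert pk("very-long-opaque-id-B") == 2

fk = optional(foreign_key("stations"))
assert fk("very-long-opaque-id-A") == 1  # known key -> small integer
assert fk("") is None                    # empty CSV cell -> NULL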
@@ -76,11 +94,11 @@ CSV_FIELDS = (
             ("route_url", None),
             ("route_color", optional(str)),
             ("route_text_color", optional(str)),
-        ),
+        ],
     ),
     (
         "trips.txt",
-        (
+        [
             ("route_id", foreign_key("routes")),
             ("service_id", int),
             ("trip_id", primary_key("trips")),
@@ -88,11 +106,11 @@ CSV_FIELDS = (
             ("direction_id", optional(int)),
             ("block_id", int),
             ("shape_id", None),
-        ),
+        ],
     ),
     (
         "stops.txt",
-        (
+        [
             ("stop_id", primary_key("stops")),
             ("stop_name", str),
             ("stop_desc", None),
@@ -102,11 +120,11 @@ CSV_FIELDS = (
             ("stop_url", None),
             ("location_type", int),
             ("parent_station", optional(foreign_key("stops"))),
-        ),
+        ],
     ),
     (
         "stop_times.txt",
-        (
+        [
             ("trip_id", foreign_key("trips")),
             ("arrival_time", str),
             ("departure_time", str),
@@ -116,14 +134,14 @@ CSV_FIELDS = (
             ("pickup_type", int),
             ("drop_off_type", int),
             ("shape_dist_traveled", None),
-        ),
+        ],
     ),
     ("feed_info.txt", None),
     ("transfers.txt", None),
-)
+]
 
 
-def _get_file_names(etag):
+def _get_file_names(etag: str):
     dir = os.path.dirname(__file__)
     return (
         os.path.join(dir, etag + ".zip"),
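Each `CSV_FIELDS` entry pairs a GTFS file name with an ordered column spec: `None` in place of the spec skips the whole file (feed_info.txt and transfers.txt above), and `None` in place of a mapper drops that column. The loading code sits mostly outside these hunks, but a hypothetical helper (not part of this commit) shows how a spec turns one CSV row into database-ready cells; it assumes the `Cell` and `Mapper` aliases from above:

def map_row(row: list[str], fields: list[tuple[str, Optional[Mapper]]]) -> list[Cell]:
    return [
        mapper(value)
        for value, (_header, mapper) in zip(row, fields, strict=True)
        if mapper is not None  # a None mapper means "drop this column"
    ]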
@@ -133,7 +151,7 @@ def _get_file_names(etag):
     )
 
 
-def _fetch_dataset(response, dataset_file):
+def _fetch_dataset(response: HTTPResponse, dataset_file: str):
     print("Fetching dataset...")
     content_length = int(response.getheader("Content-Length"))
     with open(dataset_file, "wb") as zip_out:
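The loop body of `_fetch_dataset` falls between this hunk and the next; judging from the progress line preserved as context below, it plausibly reads like this sketch (the chunk size and the use of `math.floor` are assumptions, not code shown in the diff):

        # Sketch of the elided download loop; 64 KiB chunk size is assumed.
        while chunk := response.read(64 * 1024):
            zip_out.write(chunk)
            progress = math.floor(zip_out.tell() / content_length * 100)
            print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")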
@@ -146,7 +164,7 @@ def _fetch_dataset(response, dataset_file):
             print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")
 
 
-def _check_dataset_files(zip_in):
+def _check_dataset_files(zip_in: ZipFile):
     csv_files = list(sorted(zip_in.namelist()))
     expected = list(sorted(csv_file for csv_file, _ in CSV_FIELDS))
 
@@ -154,20 +172,20 @@ def _check_dataset_files(zip_in):
     assert all(a == b for a, b in zip(csv_files, expected, strict=True)), csv_files
 
 
-def _check_csv_headers(csv_headers, fields):
+def _check_csv_headers(csv_headers: list[str], fields: list[tuple[str, Any]]):
     expected = list(header for header, _mapper in fields)
 
     assert len(expected) == len(csv_headers), csv_headers
     assert all(a == b for a, b in zip(csv_headers, expected)), csv_headers
 
 
-def _create_schema(db, schema_file):
+def _create_schema(db: DBconnection, schema_file: str):
     with db:
         with open(schema_file) as sql_in:
             db.executescript(sql_in.read())
 
 
-def _load_data(zip_in, db):
+def _load_data(zip_in: ZipFile, db: DBconnection):
     for csv_file, fields in CSV_FIELDS:
         if fields is None:
             continue
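The body of `_load_data` is mostly elided here, but the words added to the spell-check list in this commit (`executemany`, `executescript`) hint at the mechanics. A hypothetical sketch of the insert step; the table naming and the `mapped_rows` variable are assumptions, not code from this diff:

            # Bulk-insert rows already passed through the field mappers.
            table = csv_file.removesuffix(".txt")
            columns = [name for name, mapper in fields if mapper is not None]
            sql = f"INSERT INTO {table} VALUES ({', '.join(['?'] * len(columns))})"
            with db:
                db.executemany(sql, mapped_rows)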
@@ -204,7 +222,7 @@ def _load_data(zip_in, db):
         )
 
 
-def _load_database(dataset_file, database_tempfile, schema_file):
+def _load_database(dataset_file: str, database_tempfile: str, schema_file: str):
     print("Loading database...")
     with zipfile.ZipFile(dataset_file) as zip_in:
         _check_dataset_files(zip_in)
@@ -218,7 +236,7 @@ def _load_database(dataset_file, database_tempfile, schema_file):
 
 def main():
     print("Checking dataset")
-    response = request.urlopen(GTFS_URL)
+    response: HTTPResponse = request.urlopen(GTFS_URL)
     if response.status != 200:
         raise RuntimeError("Could not fetch the dataset")
 
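A final note on the `main()` change: `HTTPResponse` exists only for the type checker, yet annotating the local variable with it is safe even without the future import, because local variable annotations are never evaluated at runtime (PEP 526). A tiny illustration, where the annotation name is deliberately undefined:

def demo() -> None:
    # Runs without NameError: local annotations are not evaluated at runtime.
    value: CheckerOnlyName = 1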