# TER/data/update.py

import zipfile
import csv
import io
import sqlite3
import urllib.request as request
import os.path
import math
import contextlib as ctx
import http.client as http
from typing import Callable, Any

# GTFS reference:
# https://gtfs.org/schedule/reference/
# SNCF/TER dataset information:
# https://data.opendatasoft.com/explore/dataset/sncf-ter-gtfs%40datasncf/information/
# SNCF/TER daily updated dataset:
GTFS_URL = "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip"

Cell = int | str | float | None
Mapper = Callable[[Any], Cell]

# This data set uses large strings as primary/foreign keys that do not carry
# information. In order to save some space in the database and some time during
# requests, we substitute those large IDs with incrementing integers.

# Global register for primary key substitution
PKS: dict[str, dict[str, int]] = {}


# Primary key substitution requester
def primary_key(table: str) -> Mapper:
    assert table not in PKS
    PKS[table] = {}

    def map(v: str):
        PKS[table][v] = len(PKS[table]) + 1
        return len(PKS[table])

    return map


# Foreign key lookup
def foreign_key(table: str) -> Mapper:
    def map(v: str):
        return PKS[table][v]

    return map
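
# A quick illustration of the substitution scheme (hypothetical table name and
# IDs, kept as comments so nothing runs at import time):
#
#   pk = primary_key("example")                       # registers "example"
#   pk("some-long-opaque-id-A")                       # -> 1
#   pk("some-long-opaque-id-B")                       # -> 2
#   foreign_key("example")("some-long-opaque-id-B")   # -> 2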
# A "can be null" mapper wrapper
def optional(f: Mapper) -> Mapper:
def map(v: str):
return None if v == "" else f(v)
return map
def discarded(v: str):
# No-op mapper for field intentionally not included in the schema
return None
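

# Expected contents of the GTFS archive: for each CSV file, the ordered list of
# columns and, for each column, the mapper used to convert the raw string into
# a database cell. A None mapper means the column is expected to be empty and
# is not stored; `discarded` drops a column that may be populated but is not
# kept in the schema. The archive's file list and each file's header row are
# checked against this table before loading.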
CSV_FIELDS: list[tuple[str, list[tuple[str, Mapper | None]]]] = [
    (
        "agency.txt",
        [
            ("agency_id", primary_key("agency")),
            ("agency_name", str),
            ("agency_url", str),
            ("agency_timezone", str),
            ("agency_lang", str),
        ],
    ),
    (
        "calendar_dates.txt",
        [
            ("service_id", int),
            ("date", str),
            ("exception_type", int),
        ],
    ),
    (
        "routes.txt",
        [
            ("route_id", primary_key("routes")),
            ("agency_id", foreign_key("agency")),
            ("route_short_name", str),
            ("route_long_name", str),
            ("route_desc", None),
            ("route_type", int),
            ("route_url", None),
            ("route_color", optional(str)),
            ("route_text_color", optional(str)),
        ],
    ),
    (
        "trips.txt",
        [
            ("route_id", foreign_key("routes")),
            ("service_id", int),
            ("trip_id", primary_key("trips")),
            ("trip_headsign", str),
            ("direction_id", optional(int)),
            ("block_id", int),
            ("shape_id", None),
        ],
    ),
    (
        "stops.txt",
        [
            ("stop_id", primary_key("stops")),
            ("stop_name", str),
            ("stop_desc", None),
            ("stop_lat", float),
            ("stop_lon", float),
            ("zone_id", None),
            ("stop_url", None),
            ("location_type", int),
            ("parent_station", optional(foreign_key("stops"))),
        ],
    ),
    (
        "stop_times.txt",
        [
            ("trip_id", foreign_key("trips")),
            ("arrival_time", str),
            ("departure_time", str),
            ("stop_id", foreign_key("stops")),
            ("stop_sequence", int),
            ("stop_headsign", None),
            ("pickup_type", int),
            ("drop_off_type", int),
            ("shape_dist_traveled", None),
        ],
    ),
    (
        "feed_info.txt",
        [
            ("feed_id", int),
            ("feed_publisher_name", str),
            ("feed_publisher_url", str),
            ("feed_lang", str),
            ("feed_start_date", str),
            ("feed_end_date", str),
            ("feed_version", str),
            ("conv_rev", discarded),  # proprietary, undocumented
            ("plan_rev", discarded),  # proprietary, undocumented
        ],
    ),
    (
        "transfers.txt",
        [
            # We expect that file to have no records.
            # All None mappers will fail if it is not the case.
            ("from_stop_id", None),
            ("to_stop_id", None),
            ("transfer_type", None),
            ("min_transfer_time", None),
            ("from_route_id", None),
            ("to_route_id", None),
        ],
    ),
]
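

# All generated files live next to this module: the downloaded archive and the
# work-in-progress SQLite file are named after the dataset's ETag (so an
# unchanged ETag means the archive on disk is still current), "schema.sql"
# holds the DDL, and "db.sqlite" is the final database the temporary file is
# renamed to.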
def _get_file_names(etag: str):
    dir = os.path.dirname(__file__)
    return (
        os.path.join(dir, etag + ".zip"),
        os.path.join(dir, etag + ".sqlite"),
        os.path.join(dir, "schema.sql"),
        os.path.join(dir, "db.sqlite"),
    )
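

# Stream the HTTP response to disk in 100 KiB chunks, reporting progress based
# on the Content-Length header.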
def _fetch_dataset(response: http.HTTPResponse, dataset_file: str):
    print("Fetching dataset...")
    content_length = int(response.getheader("Content-Length"))
    with open(dataset_file, "wb") as zip_out:
        while True:
            bytes = response.read(102400)
            zip_out.write(bytes)
            if not bytes:
                break
            progress = math.floor(100 * zip_out.tell() / content_length)
            print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")


def _check_dataset_files(zip_in: zipfile.ZipFile):
    csv_files = list(sorted(zip_in.namelist()))
    expected = list(sorted(csv_file for csv_file, _ in CSV_FIELDS))
    assert len(expected) == len(csv_files), csv_files
    assert all(a == b for a, b in zip(csv_files, expected, strict=True)), csv_files


def _check_csv_headers(csv_headers: list[str], fields: list[tuple[str, Any]]):
    expected = list(header for header, _mapper in fields)
    assert len(expected) == len(csv_headers), csv_headers
    assert all(a == b for a, b in zip(csv_headers, expected)), csv_headers


def _create_schema(db: sqlite3.Connection, schema_file: str):
    with db:
        with open(schema_file) as sql_in:
            db.executescript(sql_in.read())
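

# Load every CSV file from the archive into its table (the table name is the
# file name minus ".txt"). Each row is converted with the mappers declared in
# CSV_FIELDS and all rows of a file are inserted in one transaction. A file
# whose columns are all unmapped (e.g. transfers.txt) is only accepted when it
# is empty; a populated row raises NotImplementedError so the omission gets
# noticed.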
def _load_data(zip_in: zipfile.ZipFile, db: sqlite3.Connection):
    for csv_file, fields in CSV_FIELDS:
        table = csv_file[:-4]
        print(f"Loading table {table!r}")
        with zip_in.open(csv_file, "r") as csv_in:
            reader = iter(
                csv.reader(
                    io.TextIOWrapper(
                        csv_in,
                        encoding="utf-8",
                        newline="",
                    )
                )
            )
            headers = next(reader)
            _check_csv_headers(headers, fields)
            place_holders = ",".join(
                "?" for _field, mapper in fields if mapper not in (None, discarded)
            )
            if not place_holders:
                # No column of this file is kept: it must contain no records.
                try:
                    first_row = next(reader)
                except StopIteration:
                    continue
                raise NotImplementedError(list(zip(headers, first_row)))

            def map_row(row: list[str]):
                # Columns mapped to None must be empty; discarded columns are
                # simply dropped.
                assert all(
                    not value
                    for (_field, mapper), value in zip(fields, row)
                    if mapper is None
                )
                return [
                    mapper(value)
                    for (_field, mapper), value in zip(fields, row)
                    if mapper not in (None, discarded)
                ]

            with db:
                db.executemany(
                    f"INSERT INTO {table} VALUES ({place_holders})",
                    (map_row(row) for row in reader),
                )


def _load_database(dataset_file: str, database_tempfile: str, schema_file: str):
    print("Loading database...")
    with zipfile.ZipFile(dataset_file) as zip_in:
        _check_dataset_files(zip_in)
        with ctx.closing(sqlite3.connect(database_tempfile)) as db:
            _create_schema(db, schema_file)
            _load_data(zip_in, db)
    print("Done")
def main():
    print("Checking dataset")
    response: http.HTTPResponse = request.urlopen(GTFS_URL)
    if response.status != 200:
        raise RuntimeError("Could not fetch the dataset")
    etag = response.getheader("ETag")[1:-1]
    (dataset_file, database_tempfile, schema_file, database_file) = _get_file_names(
        etag
    )
    if os.path.isfile(dataset_file):
        print("Dataset is up to date")
        response.close()
    else:
        _fetch_dataset(response, dataset_file)
        response.close()
    if os.path.isfile(database_tempfile):
        os.unlink(database_tempfile)
    _load_database(dataset_file, database_tempfile, schema_file)
    if os.path.isfile(database_file):
        os.unlink(database_file)
    os.rename(database_tempfile, database_file)


if __name__ == "__main__":
    exit(main())