import zipfile
import csv
import io
import sqlite3
import urllib.request as request
import os.path
import math
import contextlib as ctx
import http.client as http
from typing import Callable, Any
# GTFS reference:
# https://gtfs.org/schedule/reference/
# SNCF/TER dataset information:
# https://data.opendatasoft.com/explore/dataset/sncf-ter-gtfs%40datasncf/information/
# SNCF/TER daily updated dataset:
GTFS_URL = "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip"

# A value a mapper may produce for a single database cell.
Cell = int | str | float | None
# Converts one raw field value into a Cell.
Mapper = Callable[[Any], Cell]

# This data set uses large strings as primary/foreign keys that do not carry
# information. In order to save some space in the database and time during
# requests, we substitute those large IDs with incrementing integers.

# Global register for primary key substitution:
# table name -> {original key -> substituted integer}
PKS: dict[str, dict[str, int]] = {}
# Primary key substitution requester
def primary_key(table: str) -> Mapper:
    """Register ``table`` for key substitution and return its key mapper.

    The returned mapper gives every original key the next integer of the
    table (starting at 1) and records the pair in ``PKS[table]`` so that
    mappers built by ``foreign_key`` can resolve references to it.

    Raises:
        AssertionError: if ``table`` has already been registered.

    The returned mapper raises ``ValueError`` when it is handed a key it has
    already substituted.
    """
    assert table not in PKS
    PKS[table] = {}

    def substitute(v: str) -> int:
        keys = PKS[table]
        # Overwriting an existing entry would hand out a number that is
        # already in use and desynchronise the counter, so refuse duplicates.
        if v in keys:
            raise ValueError(f"duplicate primary key {v!r} in table {table!r}")
        keys[v] = len(keys) + 1
        return keys[v]

    return substitute
# Foreign key lookup
def foreign_key(table: str) -> Mapper:
    """Return a mapper resolving an original key of ``table`` to its integer.

    The register is consulted when the mapper is called, not when it is
    built, so the mapper can be created before any key has been recorded.
    An unknown table or key surfaces as ``KeyError``.
    """

    def lookup(v: str):
        substitutes = PKS[table]
        return substitutes[v]

    return lookup
# A "can be null" mapper wrapper
def optional(f: Mapper) -> Mapper:
    """Wrap ``f`` so that an empty string maps to ``None`` instead of ``f("")``."""

    def nullable(v: str):
        if v == "":
            return None
        return f(v)

    return nullable
def discarded(v: str) -> None:
    """No-op mapper for a field intentionally left out of the schema.

    Whatever the input, the result is ``None``.
    """
    return None
# Expected archive content: one (file name, [(column header, mapper)]) entry
# per CSV file, in loading order (a table must be loaded before the tables
# whose foreign keys refer to it).
#
# Mapper conventions:
#   - a callable : converts the raw value for insertion in the table
#   - discarded  : column is dropped, whatever its value
#   - None       : column is dropped and is expected to always be empty
#
# NOTE(review): the file names below were restored from the standard GTFS
# file names matching each group of columns; confirm them against the archive
# listing and against the table names in schema.sql.
CSV_FIELDS: list[tuple[str, list[tuple[str, Mapper | None]]]] = [
    (
        "agency.txt",
        [
            ("agency_id", primary_key("agency")),
            ("agency_name", str),
            ("agency_url", str),
            ("agency_timezone", str),
            ("agency_lang", str),
        ],
    ),
    (
        "calendar_dates.txt",
        [
            ("service_id", int),
            ("date", str),
            ("exception_type", int),
        ],
    ),
    (
        "routes.txt",
        [
            ("route_id", primary_key("routes")),
            ("agency_id", foreign_key("agency")),
            ("route_short_name", str),
            ("route_long_name", str),
            ("route_desc", None),
            ("route_type", int),
            ("route_url", None),
            ("route_color", optional(str)),
            ("route_text_color", optional(str)),
        ],
    ),
    (
        "trips.txt",
        [
            ("route_id", foreign_key("routes")),
            ("service_id", int),
            ("trip_id", primary_key("trips")),
            ("trip_headsign", str),
            ("direction_id", optional(int)),
            ("block_id", int),
            ("shape_id", None),
        ],
    ),
    (
        "stops.txt",
        [
            ("stop_id", primary_key("stops")),
            ("stop_name", str),
            ("stop_desc", None),
            ("stop_lat", float),
            ("stop_lon", float),
            ("zone_id", None),
            ("stop_url", None),
            ("location_type", int),
            ("parent_station", optional(foreign_key("stops"))),
        ],
    ),
    (
        "stop_times.txt",
        [
            ("trip_id", foreign_key("trips")),
            ("arrival_time", str),
            ("departure_time", str),
            ("stop_id", foreign_key("stops")),
            ("stop_sequence", int),
            ("stop_headsign", None),
            ("pickup_type", int),
            ("drop_off_type", int),
            ("shape_dist_traveled", None),
        ],
    ),
    (
        "feed_info.txt",
        [
            ("feed_id", int),
            ("feed_publisher_name", str),
            ("feed_publisher_url", str),
            ("feed_lang", str),
            ("feed_start_date", str),
            ("feed_end_date", str),
            ("feed_version", str),
            ("conv_rev", discarded),  # proprietary, undocumented
            ("plan_rev", discarded),  # proprietary, undocumented
        ],
    ),
    (
        # We expect that file to have no records.
        # All None mappers will fail if it is not the case
        "transfers.txt",
        [
            ("from_stop_id", None),
            ("to_stop_id", None),
            ("transfer_type", None),
            ("min_transfer_time", None),
            ("from_route_id", None),
            ("to_route_id", None),
        ],
    ),
]
def _get_file_names(etag: str):
    """Return the paths of the files used by the import, next to this module.

    Returns a 4-tuple, in this order:
        1. the downloaded archive, named after ``etag`` (``<etag>.zip``)
        2. the database being built, named after ``etag`` (``<etag>.sqlite``)
        3. the SQL schema (``schema.sql``)
        4. the final database (``db.sqlite``)
    """
    base_dir = os.path.dirname(__file__)
    return (
        os.path.join(base_dir, etag + ".zip"),
        os.path.join(base_dir, etag + ".sqlite"),
        os.path.join(base_dir, "schema.sql"),
        os.path.join(base_dir, "db.sqlite"),
    )
def _fetch_dataset(response: http.HTTPResponse, dataset_file: str):
    """Stream the body of ``response`` into ``dataset_file``, printing progress.

    The body is first written to ``dataset_file + ".part"`` and only moved to
    ``dataset_file`` once it has been read completely, so an interrupted
    download never leaves a truncated file under the final name.

    Args:
        response: open HTTP response whose body is the dataset archive.
        dataset_file: destination path of the archive.
    """
    print("Fetching dataset...")
    # The header is absent when the server does not announce the body size;
    # in that case only the byte count is reported.
    length_header = response.getheader("Content-Length")
    content_length = int(length_header) if length_header else 0
    partial_file = dataset_file + ".part"
    with open(partial_file, "wb") as zip_out:
        while chunk := response.read(102400):
            zip_out.write(chunk)
            fetched = zip_out.tell()
            if content_length:
                progress = math.floor(100 * fetched / content_length)
                print(f"Fetched: {fetched}/{content_length} {progress}%")
            else:
                print(f"Fetched: {fetched}")
    os.replace(partial_file, dataset_file)
def _check_dataset_files(zip_in: zipfile.ZipFile):
    """Check that the archive holds exactly the files listed in ``CSV_FIELDS``.

    Raises:
        AssertionError: carrying the sorted archive listing, when a file is
            missing from the archive or an unexpected one is present.
    """
    csv_files = sorted(zip_in.namelist())
    expected = sorted(csv_file for csv_file, _ in CSV_FIELDS)
    # Raised explicitly rather than with an ``assert`` statement so that the
    # check is not removed when Python runs with -O.
    if csv_files != expected:
        raise AssertionError(csv_files)
def _check_csv_headers(csv_headers: list[str], fields: list[tuple[str, Any]]):
    """Check that ``csv_headers`` are exactly the headers declared by ``fields``.

    Names and order must both match.

    Args:
        csv_headers: header row read from the CSV file.
        fields: ``(header, mapper)`` pairs declared for that file.

    Raises:
        AssertionError: carrying ``csv_headers``, on any mismatch.
    """
    expected = [header for header, _mapper in fields]
    # Raised explicitly rather than with an ``assert`` statement so that the
    # check is not removed when Python runs with -O.
    if list(csv_headers) != expected:
        raise AssertionError(csv_headers)
def _create_schema(db: sqlite3.Connection, schema_file: str):
    """Execute the SQL script stored in ``schema_file`` against ``db``.

    Args:
        db: open database connection.
        schema_file: path of the SQL script creating the tables.
    """
    with db:
        # NOTE(review): the script is read as UTF-8; confirm schema.sql is
        # saved with that encoding.
        with open(schema_file, encoding="utf-8") as sql_in:
            db.executescript(sql_in.read())
def _load_data(zip_in: zipfile.ZipFile, db: sqlite3.Connection):
    """Load every CSV file declared in ``CSV_FIELDS`` into its table of ``db``.

    For each file, in declaration order: the header row is checked against
    the declared fields, every record is converted with the field mappers and
    the resulting rows are inserted in the table named after the file (file
    name without its 4-character extension), one transaction per file.

    A file whose columns are all dropped has nothing to insert: it is skipped
    when it holds no record.

    Raises:
        NotImplementedError: when a file with no kept column holds a record;
            carries the (header, value) pairs of its first record.
        AssertionError: on a header mismatch, or when a column declared with
            a ``None`` mapper holds a non-empty value.
    """
    for csv_file, fields in CSV_FIELDS:
        table = csv_file[:-4]
        print(f"Loading table {table!r}")
        with zip_in.open(csv_file, "r") as csv_in:
            # NOTE(review): decoding restored as UTF-8, which the GTFS
            # reference prescribes; "utf-8-sig" also accepts the optional
            # byte order mark, which would otherwise end up in the first
            # header and fail the header check. Confirm against the dataset.
            text_in = io.TextIOWrapper(csv_in, encoding="utf-8-sig", newline="")
            reader = iter(csv.reader(text_in))
            headers = next(reader)
            _check_csv_headers(headers, fields)
            place_holders = ",".join(
                "?" for _field, mapper in fields if mapper not in (None, discarded)
            )
            if not place_holders:
                # No column is kept: the file must not contain any record.
                try:
                    first_row = next(reader)
                except StopIteration:
                    continue
                raise NotImplementedError(list(zip(headers, first_row)))

            def map_row(row: list[str]):
                cells = []
                for (field, mapper), value in zip(fields, row):
                    if mapper is None:
                        # Raised explicitly so the check survives python -O.
                        if value:
                            raise AssertionError(
                                f"unexpected value {value!r} for {field!r} in {table!r}"
                            )
                    elif mapper is not discarded:
                        cells.append(mapper(value))
                return cells

            with db:
                db.executemany(
                    f"INSERT INTO {table} VALUES ({place_holders})",
                    (map_row(row) for row in reader),
                )
def _load_database(dataset_file: str, database_tempfile: str, schema_file: str):
    """Build the database ``database_tempfile`` from the archive ``dataset_file``.

    The schema read from ``schema_file`` is created first, then the archive
    content is loaded. The archive and the connection are both closed on
    exit, whether the load succeeds or not.
    """
    print("Loading database...")
    archive = zipfile.ZipFile(dataset_file)
    with archive as zip_in, ctx.closing(sqlite3.connect(database_tempfile)) as db:
        _create_schema(db, schema_file)
        _load_data(zip_in, db)
def main():
    """Download the dataset when it changed, then rebuild the database.

    The ETag announced by the server names the local archive, so an archive
    already present under that name is not downloaded again. The database is
    built in a temporary file that replaces the final one only once the load
    has succeeded.

    Raises:
        RuntimeError: when the server does not answer with status 200 or does
            not announce an ETag.
    """
    print("Checking dataset")
    response: http.HTTPResponse
    with request.urlopen(GTFS_URL) as response:
        if response.status != 200:
            raise RuntimeError("Could not fetch the dataset")
        etag_header = response.getheader("ETag")
        if not etag_header:
            raise RuntimeError("Could not fetch the dataset: no ETag header")
        # Drop the optional weak-validator prefix, then the surrounding quotes.
        etag = etag_header.removeprefix("W/")[1:-1]
        (dataset_file, database_tempfile, schema_file, database_file) = _get_file_names(
            etag
        )
        # NOTE(review): the database is rebuilt even when the archive is up
        # to date, so that a run interrupted after the download still ends
        # with a database; confirm this is the intended flow.
        if os.path.isfile(dataset_file):
            print("Dataset is up to date")
        else:
            _fetch_dataset(response, dataset_file)
    # A leftover from an interrupted run would make the schema creation fail.
    if os.path.isfile(database_tempfile):
        os.remove(database_tempfile)
    _load_database(dataset_file, database_tempfile, schema_file)
    # Replaces an existing database in one step.
    os.replace(database_tempfile, database_file)
# Run the import only when the file is executed as a script.
if __name__ == "__main__":
    main()