diff --git a/data/schema.sql b/data/schema.sql
index 62e8b62..c7ba6fb 100644
--- a/data/schema.sql
+++ b/data/schema.sql
@@ -81,4 +81,16 @@ CREATE TABLE
 CREATE INDEX stop_times__by_trip_id
   ON stop_times (trip_id);
 CREATE INDEX stop_times__by_stop_id
-  ON stop_times (stop_id);
\ No newline at end of file
+  ON stop_times (stop_id);
+
+CREATE TABLE
+  feed_info(
+    feed_id INT NOT NULL,
+    feed_publisher_name TEXT NOT NULL,
+    feed_publisher_url TEXT NOT NULL,
+    feed_lang TEXT NOT NULL,
+    feed_start_date TEXT NOT NULL,
+    feed_end_date TEXT NOT NULL,
+    feed_version TEXT NOT NULL,
+    PRIMARY KEY (feed_id)
+  );
\ No newline at end of file
diff --git a/data/update.py b/data/update.py
index c2e1b3b..708062f 100644
--- a/data/update.py
+++ b/data/update.py
@@ -8,7 +8,7 @@ import math
 import contextlib as ctx
 import http.client as http
 
-from typing import Callable, Any, Optional
+from typing import Callable, Any
 
 
 # GTFS reference:
@@ -60,7 +60,12 @@ def optional(f: Mapper) -> Mapper:
     return map
 
 
-CSV_FIELDS: list[tuple[str, Optional[list[tuple[str, Optional[Mapper]]]]]] = [
+def discarded(v: str) -> None:
+    # No-op mapper for a field intentionally not included in the schema
+    return None
+
+
+CSV_FIELDS: list[tuple[str, list[tuple[str, Mapper | None]]]] = [
     (
         "agency.txt",
         [
@@ -133,8 +138,33 @@ CSV_FIELDS: list[tuple[str, Optional[list[tuple[str, Optional[Mapper]]]]]] = [
                 ("shape_dist_traveled", None),
             ],
         ),
-    ("feed_info.txt", None),
-    ("transfers.txt", None),
+    (
+        "feed_info.txt",
+        [
+            ("feed_id", int),
+            ("feed_publisher_name", str),
+            ("feed_publisher_url", str),
+            ("feed_lang", str),
+            ("feed_start_date", str),
+            ("feed_end_date", str),
+            ("feed_version", str),
+            ("conv_rev", discarded),  # proprietary, undocumented
+            ("plan_rev", discarded),  # proprietary, undocumented
+        ],
+    ),
+    (
+        "transfers.txt",
+        [
+            # We expect this file to have no records; the None mappers
+            # below make the load fail if that is not the case.
+            ("from_stop_id", None),
+            ("to_stop_id", None),
+            ("transfer_type", None),
+            ("min_transfer_time", None),
+            ("from_route_id", None),
+            ("to_route_id", None),
+        ],
+    ),
 ]
 
 
@@ -184,8 +214,6 @@ def _create_schema(db: sqlite3.Connection, schema_file: str):
 
 def _load_data(zip_in: zipfile.ZipFile, db: sqlite3.Connection):
     for csv_file, fields in CSV_FIELDS:
-        if fields is None:
-            continue
         table = csv_file[:-4]
 
         print(f"Loading table {table!r}")
@@ -200,22 +228,38 @@ def _load_data(zip_in: zipfile.ZipFile, db: sqlite3.Connection):
                 )
             )
         )
-        _check_csv_headers(next(reader), fields)
+        headers = next(reader)
+        _check_csv_headers(headers, fields)
 
         place_holders = ",".join(
-            "?" for _field, mapper in fields if mapper is not None
+            "?" for _field, mapper in fields if mapper not in (None, discarded)
         )
+
+        if not place_holders:
+            try:
+                first_row = next(reader)
+            except StopIteration:
+                continue
+
+            raise NotImplementedError(list(zip(headers, first_row)))
+
+        def map_row(row: list[str]):
+            assert all(
+                not value
+                for (_field, mapper), value in zip(fields, row)
+                if mapper is None
+            )
+
+            return [
+                mapper(value)
+                for (_field, mapper), value in zip(fields, row)
+                if mapper not in (None, discarded)
+            ]
+
         with db:
             db.executemany(
                 f"INSERT INTO {table} VALUES ({place_holders})",
-                (
-                    [
-                        mapper(value)
-                        for (_field, mapper), value in zip(fields, row)
-                        if mapper is not None
-                    ]
-                    for row in reader
-                ),
+                (map_row(row) for row in reader),
             )
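# Appendix: a minimal standalone sketch of how the three mapper kinds in
# this patch interact during a load. The `fields` list and the sample row
# below are hypothetical; `Mapper`, `discarded`, and `map_row` mirror the
# definitions in data/update.py.

from typing import Any, Callable

Mapper = Callable[[str], Any]


def discarded(v: str) -> None:
    # Read from the CSV, then deliberately dropped: the column has no
    # counterpart in the SQLite schema.
    return None


fields: list[tuple[str, Mapper | None]] = [
    ("feed_id", int),         # kept: converted, gets a "?" placeholder
    ("conv_rev", discarded),  # parsed but never inserted
    ("unused_col", None),     # tripwire: must always be empty
]


def map_row(row: list[str]) -> list[Any]:
    # Any non-empty value in a tripwire field aborts the load.
    assert all(
        not value
        for (_field, mapper), value in zip(fields, row)
        if mapper is None
    )
    # Only fields with a real mapper contribute a value to the INSERT.
    return [
        mapper(value)
        for (_field, mapper), value in zip(fields, row)
        if mapper not in (None, discarded)
    ]


print(map_row(["42", "r2023-09", ""]))  # -> [42]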