Fix import/typing in updates.py
This commit is contained in:
parent
77eb3d7444
commit
caf83485c7
|
@ -6,13 +6,10 @@ import urllib.request as request
|
|||
import os.path
|
||||
import math
|
||||
import contextlib as ctx
|
||||
import http.client as http
|
||||
|
||||
from typing import Callable, Any, Optional, TYPE_CHECKING
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from http.client import HTTPResponse
|
||||
from zipfile import ZipFile
|
||||
from sqlite3 import Connection as DBconnection
|
||||
|
||||
# GTFS reference:
|
||||
# https://gtfs.org/schedule/reference/
|
||||
|
@ -151,7 +148,7 @@ def _get_file_names(etag: str):
|
|||
)
|
||||
|
||||
|
||||
def _fetch_dataset(response: HTTPResponse, dataset_file: str):
|
||||
def _fetch_dataset(response: http.HTTPResponse, dataset_file: str):
|
||||
print("Fetching dataset...")
|
||||
content_length = int(response.getheader("Content-Length"))
|
||||
with open(dataset_file, "wb") as zip_out:
|
||||
|
@ -164,7 +161,7 @@ def _fetch_dataset(response: HTTPResponse, dataset_file: str):
|
|||
print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")
|
||||
|
||||
|
||||
def _check_dataset_files(zip_in: ZipFile):
|
||||
def _check_dataset_files(zip_in: zipfile.ZipFile):
|
||||
csv_files = list(sorted(zip_in.namelist()))
|
||||
expected = list(sorted(csv_file for csv_file, _ in CSV_FIELDS))
|
||||
|
||||
|
@ -179,13 +176,13 @@ def _check_csv_headers(csv_headers: list[str], fields: list[tuple[str, Any]]):
|
|||
assert all(a == b for a, b in zip(csv_headers, expected)), csv_headers
|
||||
|
||||
|
||||
def _create_schema(db: DBconnection, schema_file: str):
|
||||
def _create_schema(db: sqlite3.Connection, schema_file: str):
|
||||
with db:
|
||||
with open(schema_file) as sql_in:
|
||||
db.executescript(sql_in.read())
|
||||
|
||||
|
||||
def _load_data(zip_in: ZipFile, db: DBconnection):
|
||||
def _load_data(zip_in: zipfile.ZipFile, db: sqlite3.Connection):
|
||||
for csv_file, fields in CSV_FIELDS:
|
||||
if fields is None:
|
||||
continue
|
||||
|
@ -236,7 +233,7 @@ def _load_database(dataset_file: str, database_tempfile: str, schema_file: str):
|
|||
|
||||
def main():
|
||||
print("Checking dataset")
|
||||
response: HTTPResponse = request.urlopen(GTFS_URL)
|
||||
response: http.HTTPResponse = request.urlopen(GTFS_URL)
|
||||
if response.status != 200:
|
||||
raise RuntimeError("Could not fetch the dataset")
|
||||
|
||||
|
|
Loading…
Reference in New Issue