Fix import/typing in updates.py

Barbagus42 2023-11-07 09:32:08 +01:00
parent 77eb3d7444
commit caf83485c7
1 changed file with 7 additions and 10 deletions

updates.py

@@ -6,13 +6,10 @@ import urllib.request as request
 import os.path
 import math
 import contextlib as ctx
+import http.client as http
-from typing import Callable, Any, Optional, TYPE_CHECKING
+from typing import Callable, Any, Optional
-if TYPE_CHECKING:
-    from http.client import HTTPResponse
-    from zipfile import ZipFile
-    from sqlite3 import Connection as DBconnection
 # GTFS reference:
 # https://gtfs.org/schedule/reference/
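Note: the three names imported under TYPE_CHECKING exist only while the type checker runs; unless `from __future__ import annotations` is in effect, using them unquoted in annotations raises NameError as soon as the module is imported. Importing the modules directly and switching to qualified names, as the rest of this diff does, makes the same annotations valid at runtime. A minimal sketch of the two behaviours, not taken from the repository:

    # Before: HTTPResponse is only bound while type checking, so an
    # unquoted runtime annotation would fail with NameError on import.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from http.client import HTTPResponse

    def handle(response: "HTTPResponse") -> None:  # must stay quoted
        ...

    # After: the module is really imported, so the qualified name
    # resolves when the def statement is evaluated.
    import http.client as http

    def handle(response: http.HTTPResponse) -> None:
        ...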
@@ -151,7 +148,7 @@ def _get_file_names(etag: str):
 )


-def _fetch_dataset(response: HTTPResponse, dataset_file: str):
+def _fetch_dataset(response: http.HTTPResponse, dataset_file: str):
     print("Fetching dataset...")
     content_length = int(response.getheader("Content-Length"))
     with open(dataset_file, "wb") as zip_out:
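Only the first lines of `_fetch_dataset` are visible in this hunk; the chunked read/write loop that produces the progress line shown in the next hunk is elided. A sketch of what such a loop typically looks like, where the chunk size, the `math.floor` rounding, and the loop shape are assumptions rather than the repository's actual code:

    import math

    CHUNK_SIZE = 64 * 1024  # hypothetical; the real value is not shown

    def fetch_dataset(response, dataset_file: str):
        print("Fetching dataset...")
        content_length = int(response.getheader("Content-Length"))
        with open(dataset_file, "wb") as zip_out:
            while chunk := response.read(CHUNK_SIZE):
                zip_out.write(chunk)
                progress = math.floor(100 * zip_out.tell() / content_length)
                print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")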
@@ -164,7 +161,7 @@ def _fetch_dataset(response: HTTPResponse, dataset_file: str):
             print(f"Fetched: {zip_out.tell()}/{content_length} {progress}%")


-def _check_dataset_files(zip_in: ZipFile):
+def _check_dataset_files(zip_in: zipfile.ZipFile):
     csv_files = list(sorted(zip_in.namelist()))
     expected = list(sorted(csv_file for csv_file, _ in CSV_FIELDS))
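The comparison itself falls outside the hunk. Given the assert style visible in `_check_csv_headers` below, a plausible but unverified completion is a direct list comparison:

    import zipfile

    def check_dataset_files(zip_in: zipfile.ZipFile, csv_fields) -> None:
        csv_files = list(sorted(zip_in.namelist()))
        expected = list(sorted(csv_file for csv_file, _ in csv_fields))
        # Hypothetical: the actual comparison is not shown in the diff.
        assert csv_files == expected, csv_files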
@@ -179,13 +176,13 @@ def _check_csv_headers(csv_headers: list[str], fields: list[tuple[str, Any]]):
     assert all(a == b for a, b in zip(csv_headers, expected)), csv_headers


-def _create_schema(db: DBconnection, schema_file: str):
+def _create_schema(db: sqlite3.Connection, schema_file: str):
     with db:
         with open(schema_file) as sql_in:
             db.executescript(sql_in.read())


-def _load_data(zip_in: ZipFile, db: DBconnection):
+def _load_data(zip_in: zipfile.ZipFile, db: sqlite3.Connection):
     for csv_file, fields in CSV_FIELDS:
         if fields is None:
             continue
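`_load_data` is cut off after the `continue`; the rest of the body is not part of this commit. For orientation, a generic version of the usual zip-to-SQLite loading pattern follows. The table-naming rule, the text encoding, and the `executemany` batching are all guesses, not code from updates.py:

    import csv
    import io
    import sqlite3
    import zipfile

    def load_data(zip_in: zipfile.ZipFile, db: sqlite3.Connection, csv_fields):
        for csv_file, fields in csv_fields:
            if fields is None:
                continue  # no field spec: this file is not loaded
            table = csv_file.removesuffix(".txt")  # hypothetical naming rule
            names = ", ".join(name for name, _ in fields)
            placeholders = ", ".join("?" for _ in fields)
            with zip_in.open(csv_file) as raw:
                reader = csv.DictReader(io.TextIOWrapper(raw, encoding="utf-8"))
                with db:  # one transaction per file
                    db.executemany(
                        f"INSERT INTO {table} ({names}) VALUES ({placeholders})",
                        ([row[name] for name, _ in fields] for row in reader),
                    )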
@@ -236,7 +233,7 @@ def _load_database(dataset_file: str, database_tempfile: str, schema_file: str):
 def main():
     print("Checking dataset")
-    response: HTTPResponse = request.urlopen(GTFS_URL)
+    response: http.HTTPResponse = request.urlopen(GTFS_URL)
     if response.status != 200:
         raise RuntimeError("Could not fetch the dataset")
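For http(s) URLs, `urllib.request.urlopen` returns an `http.client.HTTPResponse`, so the new qualified annotation describes what the call actually yields. A quick check, with the GTFS reference URL standing in for the module's GTFS_URL constant:

    import http.client as http
    import urllib.request as request

    response = request.urlopen("https://gtfs.org/schedule/reference/")
    assert isinstance(response, http.HTTPResponse)
    print(response.status, response.getheader("Content-Type"))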