diff --git a/src/wireviz/svgembed.py b/src/wireviz/svgembed.py index 69961dd..ea802ca 100644 --- a/src/wireviz/svgembed.py +++ b/src/wireviz/svgembed.py @@ -4,6 +4,7 @@ import re import base64 from pathlib import Path +from typing import Union mime_subtype_replacements = {'jpg': 'jpeg', 'tif': 'tiff'} @@ -25,15 +26,15 @@ def embed_svg_images(svg_in: str, base_path: Path): return svg_out -def get_mime_subtype(filename: Path): - mime_subtype = filename.suffix.lstrip('.').lower() +def get_mime_subtype(filename: Union[str, Path]): + mime_subtype = Path(filename).suffix.lstrip('.').lower() if mime_subtype in mime_subtype_replacements: mime_subtype = mime_subtype_replacements[mime_subtype] return mime_subtype -def embed_svg_images_file(filename_in: Path, overwrite: bool = True): - filename_in = filename_in.resolve() +def embed_svg_images_file(filename_in: Union[str, Path], overwrite: bool = True): + filename_in = Path(filename_in).resolve() filename_out = filename_in.with_suffix('.b64.svg') filename_out.write_text(embed_svg_images(filename_in.read_text(), filename_in.parent)) if overwrite: