
Commit

Add CLI flag to suppress pull progress
Related: containers#684
Signed-off-by: Michael Kesper <[email protected]>
mkesper committed Feb 8, 2025
1 parent 738eda2 commit 1628c03
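
In use, the new flag suppresses the per-blob progress bar during a pull, e.g. ramalama pull --quiet tinyllama (the model name here is illustrative, not taken from this commit).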
Showing 7 changed files with 29 additions and 13 deletions.
docs/ramalama-pull.1.md: 3 additions & 0 deletions

@@ -20,6 +20,9 @@ Print usage message
 #### **--tls-verify**=*true*
 require HTTPS and verify certificates when contacting OCI registries
 
+#### **--quiet**, **-q**
+Do not show progress bar
+
 ## SEE ALSO
 **[ramalama(1)](ramalama.1.md)**

ramalama/cli.py: 7 additions & 0 deletions

@@ -644,6 +644,13 @@ def pull_parser(subparsers):
         default=True,
         help="require HTTPS and verify certificates when contacting registries",
     )
+    parser.add_argument(
+        "-q",
+        "--quiet",
+        default=False,
+        action="store_true",
+        help="Do not show progress bar during download",
+    )
     parser.add_argument("MODEL")  # positional argument
     parser.set_defaults(func=pull_cli)

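As a self-contained sketch of how the new option behaves once parsed (the parser wiring mirrors the diff; the sample argv and the print are illustrative):

import argparse

parser = argparse.ArgumentParser(prog="ramalama pull")
parser.add_argument(
    "-q",
    "--quiet",
    default=False,
    action="store_true",
    help="Do not show progress bar during download",
)
parser.add_argument("MODEL")  # positional argument

args = parser.parse_args(["--quiet", "tinyllama"])
print(args.quiet, args.MODEL)  # True tinyllama
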
ramalama/common.py: 1 addition & 1 deletion

@@ -187,7 +187,7 @@ def download_file(url, dest_path, headers=None, show_progress=True):
     while retries < max_retries:
         try:
             # Initialize HTTP client for the request
-            http_client.init(url=url, headers=headers, output_file=dest_path, progress=show_progress)
+            http_client.init(url=url, headers=headers, output_file=dest_path, show_progress=show_progress)
             return  # Exit function if successful
 
         except urllib.error.HTTPError as e:
ramalama/http_client.py: 4 additions & 4 deletions

@@ -12,7 +12,7 @@ class HttpClient:
     def __init__(self):
         pass
 
-    def init(self, url, headers, output_file, progress, response_str=None):
+    def init(self, url, headers, output_file, show_progress, response_str=None):
         output_file_partial = None
         if output_file:
             output_file_partial = output_file + ".partial"
@@ -32,7 +32,7 @@ def init(self, url, headers, output_file, progress, response_str=None):

         self.now_downloaded = 0
         self.start_time = time.time()
-        self.perform_download(out.file, progress)
+        self.perform_download(out.file, show_progress)
 
         if output_file:
             os.rename(output_file_partial, output_file)
@@ -50,7 +50,7 @@ def urlopen(self, url, headers):
         if self.response.status not in (200, 206):
             raise IOError(f"Request failed: {self.response.status}")
 
-    def perform_download(self, file, progress):
+    def perform_download(self, file, show_progress):
         self.total_to_download += self.file_size
         self.now_downloaded = 0
         self.start_time = time.time()
@@ -62,7 +62,7 @@ def perform_download(self, file, progress):
                 break
 
             size = file.write(data)
-            if progress:
+            if show_progress:
                 accumulated_size += size
                 if time.time() - last_update_time >= 0.1:
                     self.update_progress(accumulated_size)
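The loop above throttles progress output to at most one update every 0.1 s. A runnable sketch of that pattern, with in-memory buffers and a print-based reporter standing in for ramalama's progress display (those stand-ins are assumptions, not the project's API):

import io
import time

def copy_with_progress(src, dst, show_progress=True, interval=0.1):
    # Copy in 8 KiB chunks; batch progress updates so the reporter
    # runs at most once per interval rather than once per chunk.
    accumulated, last_update = 0, time.time()
    while True:
        data = src.read(8192)
        if not data:
            break
        size = dst.write(data)
        if show_progress:
            accumulated += size
            if time.time() - last_update >= interval:
                print(f"\rdownloaded {accumulated} bytes", end="", flush=True)
                last_update = time.time()

copy_with_progress(io.BytesIO(b"x" * 100_000), io.BytesIO(), show_progress=False)
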
ramalama/oci.py: 2 additions & 0 deletions

@@ -289,6 +289,8 @@ def pull(self, args):
             raise NotImplementedError("OCI images require a container engine like Podman or Docker")
 
         conman_args = [args.engine, "pull"]
+        if args.quiet:
+            conman_args.extend(['--quiet'])
         if str(args.tlsverify).lower() == "false":
             conman_args.extend([f"--tls-verify={args.tlsverify}"])
         if args.authfile:
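Mirroring the branch above with assumed values (the engine, image, and flag settings below are hypothetical, not taken from the diff), the constructed argument list comes out as follows:

# Hypothetical stand-ins for args.engine, args.quiet, args.tlsverify
engine, quiet, tlsverify = "podman", True, True

conman_args = [engine, "pull"]
if quiet:
    conman_args.extend(['--quiet'])
if str(tlsverify).lower() == "false":
    conman_args.extend([f"--tls-verify={tlsverify}"])
conman_args.append("quay.io/example/model:latest")  # hypothetical image
print(conman_args)  # ['podman', 'pull', '--quiet', 'quay.io/example/model:latest']
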
ramalama/ollama.py: 10 additions & 7 deletions

@@ -15,7 +15,7 @@ def fetch_manifest_data(registry_head, model_tag, accept):
     return manifest_data
 
 
-def pull_config_blob(repos, accept, registry_head, manifest_data):
+def pull_config_blob(repos, accept, registry_head, manifest_data, show_progress):
     cfg_hash = manifest_data["config"]["digest"]
     config_blob_path = os.path.join(repos, "blobs", cfg_hash)
 
@@ -26,11 +26,11 @@ def pull_config_blob(repos, accept, registry_head, manifest_data):
     download_file(url, config_blob_path, headers=headers, show_progress=False)
 
 
-def pull_blob(repos, layer_digest, accept, registry_head, models, model_name, model_tag, model_path):
+def pull_blob(repos, layer_digest, accept, registry_head, models, model_name, model_tag, model_path, show_progress):
     layer_blob_path = os.path.join(repos, "blobs", layer_digest)
     url = f"{registry_head}/blobs/{layer_digest}"
     headers = {"Accept": accept}
-    download_file(url, layer_blob_path, headers=headers, show_progress=True)
+    download_file(url, layer_blob_path, headers=headers, show_progress=show_progress)
 
     # Verify checksum after downloading the blob
     if not verify_checksum(layer_blob_path):
@@ -45,15 +45,15 @@ def pull_blob(repos, layer_digest, accept, registry_head, models, model_name, model_tag, model_path):
run_cmd(["ln", "-sf", relative_target_path, model_path])


def init_pull(repos, accept, registry_head, model_name, model_tag, models, model_path, model):
def init_pull(repos, accept, registry_head, model_name, model_tag, models, model_path, model, show_progress):
manifest_data = fetch_manifest_data(registry_head, model_tag, accept)
pull_config_blob(repos, accept, registry_head, manifest_data)
pull_config_blob(repos, accept, registry_head, manifest_data, show_progress)
for layer in manifest_data["layers"]:
layer_digest = layer["digest"]
if layer["mediaType"] != "application/vnd.ollama.image.model":
continue

pull_blob(repos, layer_digest, accept, registry_head, models, model_name, model_tag, model_path)
pull_blob(repos, layer_digest, accept, registry_head, models, model_name, model_tag, model_path, show_progress)

return model_path

@@ -101,11 +101,14 @@ def pull(self, args):
         if os.path.exists(model_path):
             return model_path
 
+        show_progress = not args.quiet
         registry = "https://registry.ollama.ai"
         accept = "Accept: application/vnd.docker.distribution.manifest.v2+json"
         registry_head = f"{registry}/v2/{model_name}"
         try:
-            return init_pull(repos, accept, registry_head, model_name, model_tag, models, model_path, self.model)
+            return init_pull(
+                repos, accept, registry_head, model_name, model_tag, models, model_path, self.model, show_progress
+            )
         except urllib.error.HTTPError as e:
             if "Not Found" in e.reason:
                 raise KeyError(f"{self.model} was not found in the Ollama registry")
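To trace how show_progress threads through the Ollama path, here is a toy reduction with stub bodies (signatures are simplified; only the flag flow is faithful to the hunks above). Note that in the hunks shown, the config blob download stays at show_progress=False, while model layers honor the caller's choice:

def download_file(url, show_progress=True):
    # Stub for ramalama.common.download_file
    if show_progress:
        print(f"downloading {url}")

def pull_config_blob(show_progress):
    download_file("config-blob", show_progress=False)  # small file, kept quiet

def pull_blob(digest, show_progress):
    download_file(digest, show_progress=show_progress)

def init_pull(layers, show_progress):
    pull_config_blob(show_progress)
    for digest in layers:
        pull_blob(digest, show_progress=show_progress)

init_pull(["sha256:aaa", "sha256:bbb"], show_progress=False)  # quiet pull: no output
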
ramalama/url.py: 2 additions & 1 deletion

@@ -33,9 +33,10 @@ def pull(self, args):
             os.symlink(self.model, os.path.join(symlink_dir, self.filename))
             os.symlink(self.model, target_path)
         else:
+            show_progress = not args.quiet
             url = self.type + "://" + self.model
             # Download the model file to the target path
-            download_file(url, target_path, headers={}, show_progress=True)
+            download_file(url, target_path, headers={}, show_progress=show_progress)
             relative_target_path = os.path.relpath(target_path, start=os.path.dirname(model_path))
             if self.check_valid_model_path(relative_target_path, model_path):
                 # Symlink is already correct, no need to update it
