Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 40 additions & 120 deletions main_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,125 +1,45 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import time
from typing import Optional

# main_conversion.py
"""
Model export utilities.
- Exports to TorchScript (script or trace)
- Notes how to export to CoreML using coremltools (optional)
Usage:
python main_conversion.py --ckpt checkpoints/best...pth --out-model models/mobilevit_malaria_ts.pt
"""
import argparse
import torch
from torch.cuda.amp import autocast

from cvnets import get_model
from engine.utils import autocast_fn
from options.opts import get_benchmarking_arguments
from utils import logger
from utils.common_utils import device_setup
from utils.pytorch_to_coreml import convert_pytorch_to_coreml
from utils.tensor_utils import create_rand_tensor


def cpu_timestamp(*args, **kwargs):
    """Return a wall-clock timestamp in seconds for CPU-only timing.

    Any positional or keyword arguments are accepted and ignored so that
    this function is call-compatible with ``cuda_timestamp`` (callers may
    pass ``cuda_sync=True`` to either).
    """
    # perf_counter is monotonic with the highest available resolution.
    return time.perf_counter()


def cuda_timestamp(cuda_sync=False, device=None, *args, **kwargs):
    """Return a wall-clock timestamp in seconds, optionally syncing CUDA.

    Args:
        cuda_sync: when True, block until all queued CUDA kernels on
            ``device`` have completed before reading the clock, so the
            timestamp reflects finished GPU work.
        device: CUDA device to synchronize; ``None`` means the current one.

    Extra arguments are accepted and ignored for signature compatibility
    with ``cpu_timestamp``.
    """
    if cuda_sync:
        # Kernels launch asynchronously; without this barrier the timestamp
        # would be taken while GPU work is still in flight.
        torch.cuda.synchronize(device=device)
    return time.perf_counter()


def step(
    time_fn,
    model,
    example_inputs,
    autocast_enable: bool = False,
    amp_precision: Optional[str] = "float16",
):
    """Run one forward pass and return its wall-clock duration in seconds.

    Args:
        time_fn: timestamp function (``cpu_timestamp`` or ``cuda_timestamp``).
        model: model to benchmark; invoked as ``model(example_inputs)``.
        example_inputs: input tensor for the forward pass.
        autocast_enable: enable mixed-precision autocast for the forward pass.
            Fixed: the original annotated this parameter as the literal
            ``False`` instead of the type ``bool``, and gave it no default.
        amp_precision: autocast dtype name (e.g. "float16").

    Returns:
        Elapsed time of the forward pass as a float, in seconds.
    """
    start_time = time_fn()
    with autocast_fn(enabled=autocast_enable, amp_precision=amp_precision):
        model(example_inputs)
    # cuda_sync=True makes cuda_timestamp wait for queued GPU work before
    # stopping the clock; cpu_timestamp simply ignores the argument.
    end_time = time_fn(cuda_sync=True)
    return end_time - start_time


def main_benchmark():
    """Benchmark a model's forward-pass throughput (samples per second).

    NOTE(review): in this file the body of this function is interleaved
    with unrelated TorchScript-export code below — this looks like a
    mangled diff paste; verify against the original main_benchmark.py.
    """
    # set-up: parse benchmarking CLI options
    opts = get_benchmarking_arguments()
    # device set-up (populates dev.* options)
    opts = device_setup(opts)

    # Sync batch-norm only makes sense for multi-GPU training; strip the
    # "sync_" prefix so single-device inference uses the plain variant.
    norm_layer = getattr(opts, "model.normalization.name", "batch_norm")
    if norm_layer.find("sync") > -1:
        norm_layer = norm_layer.replace("sync_", "")
        setattr(opts, "model.normalization.name", norm_layer)
    device = getattr(opts, "dev.device", torch.device("cpu"))
    # Fall back to CPU when no CUDA device is present.
    if torch.cuda.device_count() == 0:
        device = torch.device("cpu")
    # Pick the timestamp function matching the device (CUDA needs a sync).
    time_fn = cpu_timestamp if device == torch.device("cpu") else cuda_timestamp
    warmup_iterations = getattr(opts, "benchmark.warmup_iter", 10)
    iterations = getattr(opts, "benchmark.n_iter", 50)
    batch_size = getattr(opts, "benchmark.batch_size", 1)
    # Autocast mixed precision is only used on GPU; forced off on CPU.
    mixed_precision = (
        False
        if device == torch.device("cpu")
        else getattr(opts, "common.mixed_precision", False)
    )
    mixed_precision_dtype = getattr(opts, "common.mixed_precision_dtype", "float16")

    # load the model and switch to inference mode
    model = get_model(opts)
    model.eval()
    # print model information
    model.info()

    # Random input created on CPU first; moved to the target device below.
    example_inp = create_rand_tensor(opts=opts, device="cpu", batch_size=batch_size)

    # cool down for 5 seconds (lets thermals settle before timing)
    time.sleep(5)

    # Optionally benchmark the TorchScript (JIT) version of the model.
    # jit_model_only=True returns only the traced model, no CoreML artifact.
    if getattr(opts, "benchmark.use_jit_model", False):
        converted_models_dict = convert_pytorch_to_coreml(
            opts=None,
            pytorch_model=model,
            input_tensor=example_inp,
            jit_model_only=True,
        )
        model = converted_models_dict["jit"]
    model = model.to(device=device)
    example_inp = example_inp.to(device=device)
from pathlib import Path
from main_train import EnhancedMobileViT

def export_torchscript(ckpt_path, out_path, img_size=224):
    """Export an EnhancedMobileViT checkpoint to a traced TorchScript model.

    Args:
        ckpt_path: path to a .pth checkpoint; either a bare state_dict or a
            dict wrapping the weights under the "model_state" key.
        out_path: destination .pt file; parent directories are created.
        img_size: square input resolution used for the example trace input.
    """
    device = torch.device("cpu")
    model = EnhancedMobileViT(
        num_classes=2, img_size=img_size, pretrained=False, cbam=False, fusion=False
    )
    state = torch.load(ckpt_path, map_location=device)
    # Accept both wrapped checkpoints ({"model_state": ...}) and bare
    # state_dicts; presumably "model_state" matches main_train's save format
    # — TODO confirm against the trainer.
    model.load_state_dict(state.get("model_state", state))
    model.eval()
    example = torch.randn(1, 3, img_size, img_size)
    # Trace under no_grad so autograd bookkeeping is not recorded into the
    # exported inference graph (and tracing uses less memory).
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    traced.save(out_path)
    print(f"Saved TorchScript traced model to: {out_path}")

    # NOTE(review): this is the continuation of main_benchmark()'s body,
    # separated from its header by interleaved export code — looks like a
    # mangled diff paste; confirm against the original main_benchmark.py.
    with torch.no_grad():
        # warm-up: run untimed iterations so caches/cudnn autotune settle
        for i in range(warmup_iterations):
            step(
                time_fn=time_fn,
                model=model,
                example_inputs=example_inp,
                autocast_enable=mixed_precision,
                amp_precision=mixed_precision_dtype,
            )

        # Accumulators: total elapsed seconds and total samples processed.
        n_steps = n_samples = 0.0

        # run benchmark
        for i in range(iterations):
            step_time = step(
                time_fn=time_fn,
                model=model,
                example_inputs=example_inp,
                autocast_enable=mixed_precision,
                amp_precision=mixed_precision_dtype,
            )
            n_steps += step_time
            n_samples += batch_size

        # Throughput = total samples / total forward-pass time.
        logger.info(
            "Number of samples processed per second: {:.3f}".format(n_samples / n_steps)
        )

def main(args):
    """CLI entry point: export the given checkpoint to TorchScript.

    Args:
        args: parsed argparse namespace with ``ckpt``, ``out`` and
            ``img_size`` attributes.
    """
    # Bug fix: --img-size was parsed by the CLI but never forwarded, so any
    # value other than the default 224 was silently ignored.
    export_torchscript(args.ckpt, args.out, img_size=args.img_size)

if __name__ == "__main__":
    # NOTE(review): this guard mixes the old benchmark entry point with the
    # new argparse-based export CLI — looks like merged diff content; only
    # one of the two call paths should remain in the final file.
    main_benchmark()
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--out", type=str, required=True)
    parser.add_argument("--img-size", type=int, default=224)
    args = parser.parse_args()
    main(args)

# NOTE: CoreML conversion (optional)
# If you want CoreML, see coremltools and use:
# import coremltools as ct
# traced = torch.jit.load("model_ts.pt")
# mlmodel = ct.convert(traced, inputs=[ct.ImageType(name="input_1", shape=(1,3,224,224), scale=1/255.0)])
# Save mlmodel: mlmodel.save("MobileViTMalaria.mlmodel")
#
# Make sure to test inference numerically after conversion.