-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Expand file tree
/
Copy pathrun_inference_on_fw.py
More file actions
executable file
·134 lines (103 loc) · 5.24 KB
/
run_inference_on_fw.py
File metadata and controls
executable file
·134 lines (103 loc) · 5.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
To infer the model on framework runtime, you can use `run_inference_on_fw.py` script.
It infers data obtained from pointed data loader locally and saves received data into npz files.
Those files are stored in directory pointed by `--output-dir` argument.
Example call:
```shell script
python ./triton/run_inference_on_fw.py \
--input-path /models/exported/model.onnx \
--input-type onnx \
--dataloader triton/dataloader.py \
--data-dir /data/imagenet \
--batch-size 32 \
--output-dir /results/dump_local \
--dump-labels
```
"""
import argparse
import logging
import os
from pathlib import Path
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_DEPRECATION_WARNINGS"] = "0"
from tqdm import tqdm
# method from PEP-366 to support relative import in executed modules
if __package__ is None:
__package__ = Path(__file__).parent.name
from .deployment_toolkit.args import ArgParserGenerator
from .deployment_toolkit.core import DATALOADER_FN_NAME, BaseLoader, BaseRunner, Format, load_from_file
from .deployment_toolkit.dump import NpzWriter
from .deployment_toolkit.extensions import loaders, runners
LOGGER = logging.getLogger("run_inference_on_fw")
def _verify_and_format_dump(args, ids, x, y_pred, y_real):
data = {"outputs": y_pred, "ids": {"ids": ids}}
if args.dump_inputs:
data["inputs"] = x
if args.dump_labels:
if not y_real:
raise ValueError(
"Found empty label values. Please provide labels in dataloader_fn or do not use --dump-labels argument"
)
data["labels"] = y_real
return data
def _parse_and_validate_args():
supported_inputs = set(runners.supported_extensions) & set(loaders.supported_extensions)
parser = argparse.ArgumentParser(description="Dump local inference output of given model", allow_abbrev=False)
parser.add_argument("--input-path", help="Path to input model", required=True)
parser.add_argument("--input-type", help="Input model type", choices=supported_inputs, required=True)
parser.add_argument("--dataloader", help="Path to python file containing dataloader.", required=True)
parser.add_argument("--output-dir", help="Path to dir where output files will be stored", required=True)
parser.add_argument("--dump-labels", help="Dump labels to output dir", action="store_true", default=False)
parser.add_argument("--dump-inputs", help="Dump inputs to output dir", action="store_true", default=False)
parser.add_argument("-v", "--verbose", help="Verbose logs", action="store_true", default=False)
args, *_ = parser.parse_known_args()
get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
ArgParserGenerator(get_dataloader_fn).update_argparser(parser)
Loader: BaseLoader = loaders.get(args.input_type)
ArgParserGenerator(Loader, module_path=args.input_path).update_argparser(parser)
Runner: BaseRunner = runners.get(args.input_type)
ArgParserGenerator(Runner).update_argparser(parser)
args = parser.parse_args()
types_requiring_io_params = []
if args.input_type in types_requiring_io_params and not all(p for p in [args.inputs, args.outputs]):
parser.error(f"For {args.input_type} input provide --inputs and --outputs parameters")
return args
def main():
args = _parse_and_validate_args()
log_level = logging.INFO if not args.verbose else logging.DEBUG
log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
logging.basicConfig(level=log_level, format=log_format)
LOGGER.info(f"args:")
for key, value in vars(args).items():
LOGGER.info(f" {key} = {value}")
Loader: BaseLoader = loaders.get(args.input_type)
Runner: BaseRunner = runners.get(args.input_type)
loader = ArgParserGenerator(Loader, module_path=args.input_path).from_args(args)
runner = ArgParserGenerator(Runner).from_args(args)
LOGGER.info(f"Loading {args.input_path}")
model = loader.load(args.input_path)
with runner.init_inference(model=model) as runner_session, NpzWriter(args.output_dir) as writer:
get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
dataloader_fn = ArgParserGenerator(get_dataloader_fn).from_args(args)
LOGGER.info(f"Data loader initialized; Running inference")
for ids, x, y_real in tqdm(dataloader_fn(), unit="batch", mininterval=10):
y_pred = runner_session(x)
data = _verify_and_format_dump(args, ids=ids, x=x, y_pred=y_pred, y_real=y_real)
writer.write(**data)
LOGGER.info(f"Inference finished")
if __name__ == "__main__":
main()