Skip to content
This repository was archived by the owner on Nov 21, 2023. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 85 additions & 28 deletions tools/convert_pkl_to_pb.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,41 @@ def unscope_name(name):


def reset_names(names):
for i in range(0, len(names)):
for i in range(len(names)):
names[i] = unscope_name(names[i])


def convert_collect_and_distribute(
op, blobs,
roi_canonical_scale,
roi_canonical_level,
roi_max_level,
roi_min_level,
rpn_max_level,
rpn_min_level,
rpn_post_nms_topN,
):
print('Converting CollectAndDistributeFpnRpnProposals'
' Python -> C++:\n{}'.format(op))
assert op.name.startswith('CollectAndDistributeFpnRpnProposalsOp'), \
'Not valid CollectAndDistributeFpnRpnProposalsOp'

inputs = [x for x in op.input]
ret = core.CreateOperator(
'CollectAndDistributeFpnRpnProposals',
inputs,
list(op.output),
roi_canonical_scale=roi_canonical_scale,
roi_canonical_level=roi_canonical_level,
roi_max_level=roi_max_level,
roi_min_level=roi_min_level,
rpn_max_level=rpn_max_level,
rpn_min_level=rpn_min_level,
rpn_post_nms_topN=rpn_post_nms_topN,
)
return ret


def convert_gen_proposals(
op, blobs,
rpn_pre_nms_topN,
Expand All @@ -131,19 +162,19 @@ def convert_gen_proposals(
rpn_min_size,
):
print('Converting GenerateProposals Python -> C++:\n{}'.format(op))
assert op.name.startswith("GenerateProposalsOp"), "Not valid GenerateProposalsOp"
assert op.name.startswith('GenerateProposalsOp'), 'Not valid GenerateProposalsOp'

spatial_scale = mutils.get_op_arg_valf(op, "spatial_scale", None)
spatial_scale = mutils.get_op_arg_valf(op, 'spatial_scale', None)
assert spatial_scale is not None

inputs = [x for x in op.input]
anchor_name = "anchor"
anchor_name = 'anchor'
inputs.append(anchor_name)
blobs[anchor_name] = get_anchors(spatial_scale)
print('anchors {}'.format(blobs[anchor_name]))

ret = core.CreateOperator(
"GenerateProposals",
'GenerateProposals',
inputs,
list(op.output),
spatial_scale=spatial_scale,
Expand All @@ -153,7 +184,6 @@ def convert_gen_proposals(
min_size=rpn_min_size,
correct_transform_coords=True,
)

return ret, anchor_name


Expand Down Expand Up @@ -183,36 +213,59 @@ def convert_op_name(op):
reset_names(op.output)
return [op]

@op_filter(type="Python", inputs=['rpn_cls_probs', 'rpn_bbox_pred', 'im_info'])
def convert_gen_proposal(op_in):
gen_proposals_op, ext_input = convert_gen_proposals(
op_in, blobs,
rpn_min_size=float(cfg.TEST.RPN_MIN_SIZE),
rpn_post_nms_topN=cfg.TEST.RPN_POST_NMS_TOP_N,
rpn_pre_nms_topN=cfg.TEST.RPN_PRE_NMS_TOP_N,
rpn_nms_thres=cfg.TEST.RPN_NMS_THRESH,
)
net.external_input.extend([ext_input])
return [gen_proposals_op]
@op_filter(type='Python')
def convert_python(op):
if op.name.startswith('GenerateProposalsOp'):
gen_proposals_op, ext_input = convert_gen_proposals(
op, blobs,
rpn_min_size=float(cfg.TEST.RPN_MIN_SIZE),
rpn_post_nms_topN=cfg.TEST.RPN_POST_NMS_TOP_N,
rpn_pre_nms_topN=cfg.TEST.RPN_PRE_NMS_TOP_N,
rpn_nms_thres=cfg.TEST.RPN_NMS_THRESH,
)
net.external_input.extend([ext_input])
return [gen_proposals_op]
elif op.name.startswith('CollectAndDistributeFpnRpnProposalsOp'):
collect_dist_op = convert_collect_and_distribute(
op, blobs,
roi_canonical_scale=cfg.FPN.ROI_CANONICAL_SCALE,
roi_canonical_level=cfg.FPN.ROI_CANONICAL_LEVEL,
roi_max_level=cfg.FPN.ROI_MAX_LEVEL,
roi_min_level=cfg.FPN.ROI_MIN_LEVEL,
rpn_max_level=cfg.FPN.RPN_MAX_LEVEL,
rpn_min_level=cfg.FPN.RPN_MIN_LEVEL,
rpn_post_nms_topN=cfg.TEST.RPN_POST_NMS_TOP_N,
)
return [collect_dist_op]
else:
raise ValueError('Failed to convert Python op {}'.format(
op.name))

@op_filter(input_has='rois')
@op_filter()
def convert_rpn_rois(op):
for j in range(0, len(op.input)):
for j in range(len(op.input)):
if op.input[j] == 'rois':
print('Converting op {} input name: rois -> rpn_rois:\n{}'.format(
op.type, op))
op.input[j] = 'rpn_rois'
for j in range(len(op.output)):
if op.output[j] == 'rois':
print('Converting op {} output name: rois -> rpn_rois:\n{}'.format(
op.type, op))
op.output[j] = 'rpn_rois'
return [op]

@op_filter(type_in=['StopGradient', 'Alias'])
def convert_remove_op(op):
print('Removing op {}:\n{}'.format(op.type, op))
return []

# We want to apply to all operators, including converted
# so run separately
convert_op_in_proto(net, convert_remove_op)
convert_op_in_proto(net, convert_python)
convert_op_in_proto(net, convert_op_name)
convert_op_in_proto(net, [
convert_gen_proposal, convert_rpn_rois, convert_remove_op
])
convert_op_in_proto(net, convert_rpn_rois)

reset_names(net.external_input)
reset_names(net.external_output)
Expand Down Expand Up @@ -267,6 +320,7 @@ def convert_model_gpu(args, net, init_net):
cdo_cpu = mutils.get_device_option_cpu()

CPU_OPS = [
["CollectAndDistributeFpnRpnProposals", None],
["GenerateProposals", None],
["BBoxTransform", None],
["BoxWithNMSLimit", None],
Expand Down Expand Up @@ -457,7 +511,7 @@ def run_model_pb(args, net, init_net, im, check_blobs):
)

try:
workspace.RunNet(net.Proto().name)
workspace.RunNet(net)
scores = workspace.FetchBlob('score_nms')
classids = workspace.FetchBlob('class_nms')
boxes = workspace.FetchBlob('bbox_nms')
Expand Down Expand Up @@ -515,13 +569,16 @@ def main():
merge_cfg_from_list(args.opts)
cfg.NUM_GPUS = 1
assert_and_infer_cfg()
logger.info('Conerting model with config:')
logger.info('Converting model with config:')
logger.info(pprint.pformat(cfg))

assert not cfg.MODEL.KEYPOINTS_ON, "Keypoint model not supported."
assert not cfg.MODEL.MASK_ON, "Mask model not supported."
assert not cfg.FPN.FPN_ON, "FPN not supported."
assert not cfg.RETINANET.RETINANET_ON, "RetinaNet model not supported."
# script will stop when it can't find an operator rather
# than stopping based on these flags
#
# assert not cfg.MODEL.KEYPOINTS_ON, "Keypoint model not supported."
# assert not cfg.MODEL.MASK_ON, "Mask model not supported."
# assert not cfg.FPN.FPN_ON, "FPN not supported."
# assert not cfg.RETINANET.RETINANET_ON, "RetinaNet model not supported."

# load model from cfg
model, blobs = load_model(args)
Expand Down