
[Bug fix]: predict error on multi-gpu #1082

Open
kawa23 wants to merge 1 commit into matterport:master from kawa23:master

Conversation

kawa23 commented Oct 27, 2018

Prediction fails on multi-GPU: Input to reshape is a tensor with 24000 values, but the requested shape has 48000.

I believe this issue (see also #1044) is caused by the DetectionLayer output reshape:

https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/model.py#L820

 return tf.reshape(
            detections_batch,
            [self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])

which is called inside each GPU tower, before ParallelModel merges the per-tower outputs:
https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/model.py#L2043

So I think the DetectionLayer output reshape is wrong, and it should be changed to:

 return tf.reshape(
            detections_batch,
            [self.config.IMAGES_PER_GPU, self.config.DETECTION_MAX_INSTANCES, 6])

It worked for me on multi-GPU (two GPUs).
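
For context, a minimal sketch of the shape mismatch (illustrative only, plain TensorFlow; it assumes the relationship BATCH_SIZE = IMAGES_PER_GPU * GPU_COUNT from config.py and that ParallelModel splits inputs along the batch axis):

    import tensorflow as tf

    GPU_COUNT = 2
    IMAGES_PER_GPU = 2
    BATCH_SIZE = IMAGES_PER_GPU * GPU_COUNT   # 4, as in config.py
    DETECTION_MAX_INSTANCES = 100

    # ParallelModel splits the input batch, so inside a tower the
    # flattened detections cover only IMAGES_PER_GPU images:
    tower_detections = tf.zeros([IMAGES_PER_GPU * DETECTION_MAX_INSTANCES * 6])

    # The per-GPU shape is valid:
    ok = tf.reshape(tower_detections,
                    [IMAGES_PER_GPU, DETECTION_MAX_INSTANCES, 6])

    # The full-batch shape asks for GPU_COUNT times the available values:
    # tf.reshape(tower_detections, [BATCH_SIZE, DETECTION_MAX_INSTANCES, 6])
    # -> "Input to reshape is a tensor with 1200 values, but the requested
    #     shape has 2400"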

@waleedka (Collaborator)

@kawa23 Thanks for the pull request. I'm reviewing it and trying to trace the issue. A quick question about the error you reported in #1044: How come your tensor shape is 24000? I'd expect it to be 2400 since your batch size is 4.

BATCH_SIZE * DETECTION_MAX_INSTANCES * 6

Did you happen to change DETECTION_MAX_INSTANCES from the default value of 100?

I acknowledge the error you got, but I'm suspecting that the issue might be somewhere else.

kawa23 (Author) commented Oct 28, 2018

> @kawa23 Thanks for the pull request. I'm reviewing it and trying to trace the issue. A quick question about the error you reported in #1044: How come your tensor shape is 24000? I'd expect it to be 2400 since your batch size is 4.
>
> BATCH_SIZE * DETECTION_MAX_INSTANCES * 6
>
> Did you happen to change DETECTION_MAX_INSTANCES from the default value of 100?
>
> I acknowledge the error you got, but I'm suspecting that the issue might be somewhere else.

Yes. When I changed DETECTION_MAX_INSTANCES from the default value of 100 to 2000, I got this error: Input to reshape is a tensor with 24000 values, but the requested shape has 48000.

With DETECTION_MAX_INSTANCES at the default 100, it was: Input to reshape is a tensor with 1200 values, but the requested shape has 2400.

Across several other data tests, the requested size was always twice the input size, which matches my GPU_COUNT=2 setting.
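
Quick arithmetic confirming that ratio (a sketch only; it assumes IMAGES_PER_GPU=2 and BATCH_SIZE=4, as established in #1044):

    # input     = IMAGES_PER_GPU * DETECTION_MAX_INSTANCES * 6
    # requested = BATCH_SIZE     * DETECTION_MAX_INSTANCES * 6
    for det_max in (2000, 100):
        input_values = 2 * det_max * 6   # 24000, then 1200
        requested = 4 * det_max * 6      # 48000, then 2400
        assert requested == 2 * input_values   # ratio == GPU_COUNT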

I also found, while tracing it, that the error is thrown here (#L819 ~ #L821):

return tf.reshape(
    detections_batch,
    [self.config.IMAGES_PER_GPU, self.config.DETECTION_MAX_INSTANCES, 6])

Maybe I should test with GPU_COUNT=3 or more, but I only have two GPUs.

Thanks for your reply; I'd appreciate it if you can track it down and fix it.

@keineahnung2345 (Contributor)

@kawa23 Our company has a DGX-1, and I have tested your code on 1 to 8 GPUs. With 1 to 4 GPUs everything works fine, but with 5, 6, or 7 GPUs it gives me the following error:

# set gpu_count=5 (similar error when gpu_count = 6 or 7)
InvalidArgumentErrorTraceback (most recent call last)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1321     try:
-> 1322       return fn(*args)
   1323     except errors.OpError as e:

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1306       return self._call_tf_sessionrun(
-> 1307           options, feed_dict, fetch_list, target_list, run_metadata)
   1308 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1408           self._session, options, feed_dict, fetch_list, target_list,
-> 1409           run_metadata)
   1410     else:

InvalidArgumentError: Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 32) and num_split 5
	 [[Node: split_4 = Split[T=DT_FLOAT, num_split=5, _device="/job:localhost/replica:0/task:0/device:GPU:0"](tower_0_1/mask_rcnn/mrcnn_mask_deconv/add/y, _arg_input_image_meta_1_0_2/_25023)]]
	 [[Node: tower_4_1/mask_rcnn/mrcnn_detection/map_5/while/LoopCond/_34303 = _HostRecv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:4", send_device_incarnation=1, tensor_name="edge_39553_tower_4_1/mask_rcnn/mrcnn_detection/map_5/while/LoopCond", tensor_type=DT_BOOL, _device="/job:localhost/replica:0/task:0/device:CPU:0"](^_clooptower_4_1/mask_rcnn/mrcnn_detection/map_5/while/TensorArrayReadV3/_24766)]]

During handling of the above exception, another exception occurred:

InvalidArgumentErrorTraceback (most recent call last)
<ipython-input-28-94f99786ace2> in <module>()
     33     # Run object detection
     34     # it can only accept input with the size equal to INFERENCE_BATCH_SIZE
---> 35     results = model.detect(images, verbose=0)
     36     # delete the result from dummy images
     37     results = results[:cur_batch_size]

/notebooks/Lorenzo/Mask_RCNN/mrcnn/model_test_1082.py in detect(self, images, verbose)
   2578         # Run object detection
   2579         detections, _, _, mrcnn_mask, _, _, _ =\
-> 2580             self.keras_model.predict([molded_images, image_metas, anchors], verbose=0)
   2581         # Process detections
   2582         results = []

/usr/local/lib/python3.5/dist-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
   1798         f = self.predict_function
   1799         return self._predict_loop(f, ins, batch_size=batch_size,
-> 1800                                   verbose=verbose, steps=steps)
   1801 
   1802     def train_on_batch(self, x, y,

/usr/local/lib/python3.5/dist-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose, steps)
   1299                     ins_batch[i] = ins_batch[i].toarray()
   1300 
-> 1301                 batch_outs = f(ins_batch)
   1302                 if not isinstance(batch_outs, list):
   1303                     batch_outs = [batch_outs]

/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2473         session = get_session()
   2474         updated = session.run(fetches=fetches, feed_dict=feed_dict,
-> 2475                               **self.session_kwargs)
   2476         return updated[:len(self.outputs)]
   2477 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    898     try:
    899       result = self._run(None, fetches, feed_dict, options_ptr,
--> 900                          run_metadata_ptr)
    901       if run_metadata:
    902         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1133     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1134       results = self._do_run(handle, final_targets, final_fetches,
-> 1135                              feed_dict_tensor, options, run_metadata)
   1136     else:
   1137       results = []

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1314     if handle is None:
   1315       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1316                            run_metadata)
   1317     else:
   1318       return self._do_call(_prun_fn, handle, feeds, fetches)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1333         except KeyError:
   1334           pass
-> 1335       raise type(e)(node_def, op, message)
   1336 
   1337   def _extend_graph(self):

InvalidArgumentError: Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 32) and num_split 5
	 [[Node: split_4 = Split[T=DT_FLOAT, num_split=5, _device="/job:localhost/replica:0/task:0/device:GPU:0"](tower_0_1/mask_rcnn/mrcnn_mask_deconv/add/y, _arg_input_image_meta_1_0_2/_25023)]]
	 [[Node: tower_4_1/mask_rcnn/mrcnn_detection/map_5/while/LoopCond/_34303 = _HostRecv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:4", send_device_incarnation=1, tensor_name="edge_39553_tower_4_1/mask_rcnn/mrcnn_detection/map_5/while/LoopCond", tensor_type=DT_BOOL, _device="/job:localhost/replica:0/task:0/device:CPU:0"](^_clooptower_4_1/mask_rcnn/mrcnn_detection/map_5/while/TensorArrayReadV3/_24766)]]

Caused by op 'split_4', defined at:
  File "/usr/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.5/dist-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelapp.py", line 499, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.5/dist-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.5/asyncio/base_events.py", line 345, in run_forever
    self._run_once()
  File "/usr/lib/python3.5/asyncio/base_events.py", line 1312, in _run_once
    handle._run()
  File "/usr/lib/python3.5/asyncio/events.py", line 125, in _run
    self._callback(*self._args)
  File "/usr/local/lib/python3.5/dist-packages/tornado/ioloop.py", line 758, in _run_callback
    ret = callback()
  File "/usr/local/lib/python3.5/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 1233, in inner
    self.run()
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 1147, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 346, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 259, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 513, in execute_request
    user_expressions, allow_stdin,
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/ipkernel.py", line 294, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2901, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2961, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-27-0a7432c8f920>", line 19, in <module>
    model_dir=MODEL_DIR)
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/model_test_1082.py", line 1889, in __init__
    self.keras_model = self.build(mode=mode, config=config)
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/model_test_1082.py", line 2114, in build
    model = ParallelModel(model, config.GPU_COUNT)
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/parallel_model.py", line 37, in __init__
    merged_outputs = self.make_parallel()
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/parallel_model.py", line 62, in make_parallel
    self.inner_model.inputs)}
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/parallel_model.py", line 61, in <dictcomp>
    for name, x in zip(self.inner_model.input_names,
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 1315, in split
    axis=axis, num_split=num_or_size_splits, value=value, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 7793, in split
    "Split", split_dim=axis, value=value, num_split=num_split, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3414, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1740, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 32) and num_split 5
	 [[Node: split_4 = Split[T=DT_FLOAT, num_split=5, _device="/job:localhost/replica:0/task:0/device:GPU:0"](tower_0_1/mask_rcnn/mrcnn_mask_deconv/add/y, _arg_input_image_meta_1_0_2/_25023)]]
	 [[Node: tower_4_1/mask_rcnn/mrcnn_detection/map_5/while/LoopCond/_34303 = _HostRecv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:4", send_device_incarnation=1, tensor_name="edge_39553_tower_4_1/mask_rcnn/mrcnn_detection/map_5/while/LoopCond", tensor_type=DT_BOOL, _device="/job:localhost/replica:0/task:0/device:CPU:0"](^_clooptower_4_1/mask_rcnn/mrcnn_detection/map_5/while/TensorArrayReadV3/_24766)]]
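
This first failure seems to come from the batch split itself rather than from this patch: the traceback ends in ParallelModel's tf.split, which requires the batch dimension (size 32 here) to be evenly divisible by the number of towers, and 32 is not divisible by 5, 6, or 7. A standalone sketch of that constraint (plain TensorFlow, nothing from this repo):

    import tensorflow as tf

    batch = tf.zeros([32, 4])           # axis 0 (the batch) has size 32
    ok = tf.split(batch, 4, axis=0)     # fine: four [8, 4] slices
    try:
        tf.split(batch, 5, axis=0)      # 32 is not divisible by 5
    except Exception as e:
        print(e)  # complains the split must evenly divide the dimension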

And when I tested on 8 GPUs, it gave me another error:

# set gpu_count=8
InvalidArgumentErrorTraceback (most recent call last)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1321     try:
-> 1322       return fn(*args)
   1323     except errors.OpError as e:

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1306       return self._call_tf_sessionrun(
-> 1307           options, feed_dict, fetch_list, target_list, run_metadata)
   1308 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1408           self._session, options, feed_dict, fetch_list, target_list,
-> 1409           run_metadata)
   1410     else:

InvalidArgumentError: slice index 4 of dimension 0 out of bounds.
	 [[Node: tower_0/mask_rcnn/ROI/strided_slice_42 = StridedSlice[Index=DT_INT32, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1, _device="/job:localhost/replica:0/task:0/device:GPU:0"](split_2, tower_0/mask_rcnn/mrcnn_detection/strided_slice_77/stack_1, tower_0/mask_rcnn/mrcnn_detection/strided_slice_100/stack_1, tower_0/mask_rcnn/mrcnn_mask_deconv/strided_slice/stack_1)]]
	 [[Node: res4c_branch2a/kernel/read/_11935 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:5", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_15698_res4c_branch2a/kernel/read", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:5"]()]]

During handling of the above exception, another exception occurred:

InvalidArgumentErrorTraceback (most recent call last)
<ipython-input-19-94f99786ace2> in <module>()
     33     # Run object detection
     34     # it can only accept input with the size equal to INFERENCE_BATCH_SIZE
---> 35     results = model.detect(images, verbose=0)
     36     # delete the result from dummy images
     37     results = results[:cur_batch_size]

/notebooks/Lorenzo/Mask_RCNN/mrcnn/model_test_1082.py in detect(self, images, verbose)
   2578         # Run object detection
   2579         detections, _, _, mrcnn_mask, _, _, _ =\
-> 2580             self.keras_model.predict([molded_images, image_metas, anchors], verbose=0)
   2581         # Process detections
   2582         results = []

/usr/local/lib/python3.5/dist-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
   1798         f = self.predict_function
   1799         return self._predict_loop(f, ins, batch_size=batch_size,
-> 1800                                   verbose=verbose, steps=steps)
   1801 
   1802     def train_on_batch(self, x, y,

/usr/local/lib/python3.5/dist-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose, steps)
   1299                     ins_batch[i] = ins_batch[i].toarray()
   1300 
-> 1301                 batch_outs = f(ins_batch)
   1302                 if not isinstance(batch_outs, list):
   1303                     batch_outs = [batch_outs]

/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2473         session = get_session()
   2474         updated = session.run(fetches=fetches, feed_dict=feed_dict,
-> 2475                               **self.session_kwargs)
   2476         return updated[:len(self.outputs)]
   2477 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    898     try:
    899       result = self._run(None, fetches, feed_dict, options_ptr,
--> 900                          run_metadata_ptr)
    901       if run_metadata:
    902         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1133     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1134       results = self._do_run(handle, final_targets, final_fetches,
-> 1135                              feed_dict_tensor, options, run_metadata)
   1136     else:
   1137       results = []

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1314     if handle is None:
   1315       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1316                            run_metadata)
   1317     else:
   1318       return self._do_call(_prun_fn, handle, feeds, fetches)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1333         except KeyError:
   1334           pass
-> 1335       raise type(e)(node_def, op, message)
   1336 
   1337   def _extend_graph(self):

InvalidArgumentError: slice index 4 of dimension 0 out of bounds.
	 [[Node: tower_0/mask_rcnn/ROI/strided_slice_42 = StridedSlice[Index=DT_INT32, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1, _device="/job:localhost/replica:0/task:0/device:GPU:0"](split_2, tower_0/mask_rcnn/mrcnn_detection/strided_slice_77/stack_1, tower_0/mask_rcnn/mrcnn_detection/strided_slice_100/stack_1, tower_0/mask_rcnn/mrcnn_mask_deconv/strided_slice/stack_1)]]
	 [[Node: res4c_branch2a/kernel/read/_11935 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:5", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_15698_res4c_branch2a/kernel/read", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:5"]()]]

Caused by op 'tower_0/mask_rcnn/ROI/strided_slice_42', defined at:
  File "/usr/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.5/dist-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelapp.py", line 499, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.5/dist-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.5/asyncio/base_events.py", line 345, in run_forever
    self._run_once()
  File "/usr/lib/python3.5/asyncio/base_events.py", line 1312, in _run_once
    handle._run()
  File "/usr/lib/python3.5/asyncio/events.py", line 125, in _run
    self._callback(*self._args)
  File "/usr/local/lib/python3.5/dist-packages/tornado/ioloop.py", line 758, in _run_callback
    ret = callback()
  File "/usr/local/lib/python3.5/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 1233, in inner
    self.run()
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 1147, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 346, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 259, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 513, in execute_request
    user_expressions, allow_stdin,
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/ipkernel.py", line 294, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2901, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2961, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-18-75d4a9eceb01>", line 19, in <module>
    model_dir=MODEL_DIR)
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/model_test_1082.py", line 1889, in __init__
    self.keras_model = self.build(mode=mode, config=config)
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/model_test_1082.py", line 2114, in build
    model = ParallelModel(model, config.GPU_COUNT)
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/parallel_model.py", line 37, in __init__
    merged_outputs = self.make_parallel()
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/parallel_model.py", line 83, in make_parallel
    outputs = self.inner_model(inputs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 617, in __call__
    output = self.call(inputs, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 2078, in call
    output_tensors, _, _ = self.run_internal_graph(inputs, masks)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 2240, in run_internal_graph
    output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/model_test_1082.py", line 309, in call
    names=["pre_nms_anchors"])
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/utils.py", line 829, in batch_slice
    inputs_slice = [x[i] for x in inputs]
  File "/notebooks/Lorenzo/Mask_RCNN/mrcnn/utils.py", line 829, in <listcomp>
    inputs_slice = [x[i] for x in inputs]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 523, in _slice_helper
    name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 689, in strided_slice
    shrink_axis_mask=shrink_axis_mask)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 8232, in strided_slice
    name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3414, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1740, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): slice index 4 of dimension 0 out of bounds.
	 [[Node: tower_0/mask_rcnn/ROI/strided_slice_42 = StridedSlice[Index=DT_INT32, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1, _device="/job:localhost/replica:0/task:0/device:GPU:0"](split_2, tower_0/mask_rcnn/mrcnn_detection/strided_slice_77/stack_1, tower_0/mask_rcnn/mrcnn_detection/strided_slice_100/stack_1, tower_0/mask_rcnn/mrcnn_mask_deconv/strided_slice/stack_1)]]
	 [[Node: res4c_branch2a/kernel/read/_11935 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:5", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_15698_res4c_branch2a/kernel/read", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:5"]()]]

Any ideas?

@lixindou2018

@keineahnung2345 @kawa23 I recently tried testing Mask RCNN on multiple GPUs, but the speed did not improve. Could you post your code? Thank you.


tianyu-tristan commented Feb 11, 2019

> @kawa23 Thanks for the pull request. I'm reviewing it and trying to trace the issue. A quick question about the error you reported in #1044: How come your tensor shape is 24000? I'd expect it to be 2400 since your batch size is 4.
>
> BATCH_SIZE * DETECTION_MAX_INSTANCES * 6
>
> Did you happen to change DETECTION_MAX_INSTANCES from the default value of 100?
>
> I acknowledge the error you got, but I'm suspecting that the issue might be somewhere else.

@waleedka Hi, thank you for acknowledging the error. I'm having the same issue (2 GPUs, 2 images/GPU, no config changes, so my numbers are 2400 as you stated here, rather than 24000). What's the proper fix for this?


vikiQiu commented Jul 8, 2019

> @kawa23 Thanks for the pull request. I'm reviewing it and trying to trace the issue. A quick question about the error you reported in #1044: How come your tensor shape is 24000? I'd expect it to be 2400 since your batch size is 4.
>
> BATCH_SIZE * DETECTION_MAX_INSTANCES * 6
>
> Did you happen to change DETECTION_MAX_INSTANCES from the default value of 100?
>
> I acknowledge the error you got, but I'm suspecting that the issue might be somewhere else.

I got a similar problem to the one above:
Prediction fails on multi-GPU: Input to reshape is a tensor with 600 values, but the requested shape has 1200.
I kept DETECTION_MAX_INSTANCES at the default 100, with IMAGES_PER_GPU=2 and GPU_COUNT=1.
The problem was solved when I applied this change. It's quite helpful; I think it should be merged into master.

@phongvu99 left a comment

Works like a charm!

@Aashish-Gautam left a comment

It fixed my issue, thanks for the help.
Not sure why this change hasn't already been made in model.py.


Labels: None yet

8 participants