Skip to content

Commit

Permalink
[MigraphX] Fix potential synchronization problem when ORT_ENABLE_STRE…
Browse files Browse the repository at this point in the history
…AM is true (#22589)

### Description
Replace `hipMemcpy` with `hipMemcpyWithStream`



### Motivation and Context
`hipMemcpy` uses default stream, which may be out of synchronization
with the current stream when ORT_ENABLE_STREAM is defined.
  • Loading branch information
xinyazhang authored Oct 25, 2024
1 parent 7acbd51 commit 28efacf
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle())));
} else {
// copy from other CPU memory to GPU, this is blocking
HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
}
} else if (src_device.Type() == OrtDevice::GPU) {
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1445,7 +1445,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
std::vector<int64_t> ort_shape{res_lens.begin(), res_lens.end()};
auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size());
void* output_data = output_tensor.GetTensorMutableRawData();
HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice));
HIP_CALL_THROW(hipMemcpyWithStream(output_data,
gpu_res.data(),
res_shape.bytes(),
hipMemcpyDeviceToDevice,
static_cast<hipStream_t>(rocm_stream)));
}
}
};
Expand Down

0 comments on commit 28efacf

Please sign in to comment.