164 lines
4.6 KiB
C++
164 lines
4.6 KiB
C++
#include <uds/client_channel.h>

#include <sys/socket.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <thread>
#include <type_traits>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <pdx/client.h>
#include <pdx/rpc/remote_method.h>
#include <pdx/service.h>
#include <pdx/service_dispatcher.h>

#include <uds/client_channel_factory.h>
#include <uds/service_endpoint.h>

using testing::Return;
using testing::_;

using android::pdx::ClientBase;
using android::pdx::LocalChannelHandle;
using android::pdx::LocalHandle;
using android::pdx::Message;
using android::pdx::ServiceBase;
using android::pdx::ServiceDispatcher;
using android::pdx::Status;
using android::pdx::rpc::DispatchRemoteMethod;
using android::pdx::uds::ClientChannel;
using android::pdx::uds::ClientChannelFactory;
using android::pdx::uds::Endpoint;
|
|
|
|
namespace {
|
|
|
|
struct TestProtocol {
|
|
using DataType = int8_t;
|
|
enum {
|
|
kOpSum = 0,
|
|
};
|
|
PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
|
|
};
|
|
|
|
// Minimal PDX service used by the tests: answers TestProtocol::Sum requests
// and defers every other opcode to the base Service implementation.
class TestService : public ServiceBase<TestService> {
 public:
  // Takes ownership of |endpoint| and registers under the name "TestService".
  explicit TestService(std::unique_ptr<Endpoint> endpoint)
      : ServiceBase{"TestService", std::move(endpoint)} {}

  // Routes kOpSum to OnSum via DispatchRemoteMethod; anything else falls
  // through to Service::HandleMessage.
  Status<void> HandleMessage(Message& message) override {
    const auto op = message.GetOp();
    if (op == TestProtocol::kOpSum) {
      DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
                                              message);
      return {};
    }
    return Service::HandleMessage(message);
  }

  // Returns the sum of all elements of |data|, accumulated in int64_t so
  // that the int8_t elements cannot overflow the running total.
  int64_t OnSum(Message& /*message*/,
                const std::vector<TestProtocol::DataType>& data) {
    int64_t total = 0;
    for (const TestProtocol::DataType value : data)
      total += value;
    return total;
  }
};
|
|
|
|
// Client-side proxy for TestService; inherits ClientBase's constructors.
class TestClient : public ClientBase<TestClient> {
 public:
  using ClientBase::ClientBase;

  // Invokes the remote Sum method with |data|.
  // Returns the remote result, or -1 if the RPC fails. NOTE(review): -1 is
  // also a legitimate sum of signed int8_t values, so a failure is
  // indistinguishable from a payload that sums to -1.
  int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
    auto result = InvokeRemoteMethod<TestProtocol::Sum>(data);
    if (!result)
      return -1;
    return result.get();
  }
};
|
|
|
|
class TestServiceRunner {
|
|
public:
|
|
explicit TestServiceRunner(LocalHandle channel_socket) {
|
|
auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
|
|
endpoint->RegisterNewChannelForTests(std::move(channel_socket));
|
|
service_ = TestService::Create(std::move(endpoint));
|
|
dispatcher_ = ServiceDispatcher::Create();
|
|
dispatcher_->AddService(service_);
|
|
dispatch_thread_ = std::thread(
|
|
std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get()));
|
|
}
|
|
|
|
~TestServiceRunner() {
|
|
dispatcher_->SetCanceled(true);
|
|
dispatch_thread_.join();
|
|
dispatcher_->RemoveService(service_);
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<TestService> service_;
|
|
std::unique_ptr<ServiceDispatcher> dispatcher_;
|
|
std::thread dispatch_thread_;
|
|
};
|
|
|
|
class ClientChannelTest : public testing::Test {
|
|
public:
|
|
void SetUp() override {
|
|
int channel_sockets[2] = {};
|
|
ASSERT_EQ(
|
|
0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets));
|
|
LocalHandle service_channel{channel_sockets[0]};
|
|
LocalHandle client_channel{channel_sockets[1]};
|
|
|
|
service_runner_.reset(new TestServiceRunner{std::move(service_channel)});
|
|
auto factory = ClientChannelFactory::Create(std::move(client_channel));
|
|
auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout);
|
|
ASSERT_TRUE(status);
|
|
client_ = TestClient::Create(status.take());
|
|
}
|
|
|
|
void TearDown() override {
|
|
service_runner_.reset();
|
|
client_.reset();
|
|
}
|
|
|
|
protected:
|
|
std::unique_ptr<TestServiceRunner> service_runner_;
|
|
std::shared_ptr<TestClient> client_;
|
|
};
|
|
|
|
// Issues Sum RPCs concurrently from several threads through the single shared
// client_ and checks that every reply equals the locally computed total.
TEST_F(ClientChannelTest, MultithreadedClient) {
  constexpr int kNumTestThreads = 8;
  constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.

  using DataType = TestProtocol::DataType;
  using Limits = std::numeric_limits<DataType>;
  // std::uniform_int_distribution is only specified for short, int, long,
  // long long and their unsigned counterparts; instantiating it with int8_t
  // (a character type) is undefined behavior and is rejected outright by
  // some standard libraries. Draw in a wide-enough permitted type and
  // narrow each value back to DataType.
  using DrawType = std::common_type<DataType, int>::type;

  std::random_device rd;
  std::mt19937 gen{rd()};
  std::uniform_int_distribution<DrawType> dist{Limits::min(), Limits::max()};

  // Each worker owns its copy of the client pointer and of its payload.
  auto worker = [](std::shared_ptr<TestClient> client,
                   std::vector<DataType> data) {
    constexpr int kMaxIterations = 500;
    const int64_t expected =
        std::accumulate(data.begin(), data.end(), int64_t{0});
    for (int i = 0; i < kMaxIterations; i++) {
      // ASSERT_* only returns from this lambda, ending this worker early.
      ASSERT_EQ(expected, client->Sum(data)) << "iteration " << i;
    }
  };

  // Start client threads, each with a freshly randomized payload.
  std::vector<DataType> data(kDataSize);
  std::vector<std::thread> threads;
  threads.reserve(kNumTestThreads);
  for (int i = 0; i < kNumTestThreads; i++) {
    std::generate(data.begin(), data.end(), [&dist, &gen]() {
      return static_cast<DataType>(dist(gen));
    });
    threads.emplace_back(worker, client_, data);
  }

  // Wait for threads to finish.
  for (auto& thread : threads)
    thread.join();
}
|
|
|
|
} // namespace
|