| // Copyright 2014 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/cast_channel/cast_transport.h" |
| |
| #include <stddef.h> |
| #include <stdint.h> |
| #include <utility> |
| |
| #include "base/bind.h" |
| #include "base/callback_helpers.h" |
| #include "base/containers/queue.h" |
| #include "base/macros.h" |
| #include "base/memory/ptr_util.h" |
| #include "base/run_loop.h" |
| #include "base/test/simple_test_clock.h" |
| #include "base/test/task_environment.h" |
| #include "components/cast_channel/cast_framer.h" |
| #include "components/cast_channel/cast_socket.h" |
| #include "components/cast_channel/cast_test_util.h" |
| #include "components/cast_channel/logger.h" |
| #include "net/base/completion_once_callback.h" |
| #include "net/base/completion_repeating_callback.h" |
| #include "net/base/net_errors.h" |
| #include "net/log/test_net_log.h" |
| #include "net/socket/socket.h" |
| #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" |
| #include "services/network/network_context.h" |
| #include "testing/gmock/include/gmock/gmock.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| #include "third_party/openscreen/src/cast/common/channel/proto/cast_channel.pb.h" |
| |
| using testing::_; |
| using testing::DoAll; |
| using testing::InSequence; |
| using testing::Invoke; |
| using testing::NotNull; |
| using testing::Return; |
| using testing::WithArg; |
| |
| namespace cast_channel { |
| namespace { |
| |
// Channel ID given to the transport under test; also the key used when
// querying the logger for the last recorded error.
constexpr int kChannelId = 0;
| |
| // Mockable placeholder for write completion events. |
| class CompleteHandler { |
| public: |
| CompleteHandler() {} |
| MOCK_METHOD1(Complete, void(int result)); |
| |
| private: |
| DISALLOW_COPY_AND_ASSIGN(CompleteHandler); |
| }; |
| |
| // Creates a CastMessage proto with the bare minimum required fields set. |
| CastMessage CreateCastMessage() { |
| CastMessage output; |
| output.set_protocol_version(CastMessage::CASTV2_1_0); |
| output.set_namespace_("x"); |
| output.set_source_id("source"); |
| output.set_destination_id("destination"); |
| output.set_payload_type(CastMessage::STRING); |
| output.set_payload_utf8("payload"); |
| return output; |
| } |
| |
| // FIFO queue of completion callbacks. Outstanding write operations are |
| // Push()ed into the queue. Callback completion is simulated by invoking |
| // Pop() in the same order as Push(). |
| class CompletionQueue { |
| public: |
| CompletionQueue() {} |
| ~CompletionQueue() { CHECK_EQ(0u, cb_queue_.size()); } |
| |
| // Enqueues a pending completion callback. |
| void Push(net::CompletionOnceCallback cb) { cb_queue_.push(std::move(cb)); } |
| // Runs the next callback and removes it from the queue. |
| void Pop(int rv) { |
| CHECK_GT(cb_queue_.size(), 0u); |
| std::move(cb_queue_.front()).Run(rv); |
| cb_queue_.pop(); |
| } |
| |
| private: |
| base::queue<net::CompletionOnceCallback> cb_queue_; |
| DISALLOW_COPY_AND_ASSIGN(CompletionQueue); |
| }; |
| |
| // GMock action that reads data from an IOBuffer and writes it to a string |
| // variable. |
| // |
| // buf_idx (template parameter 0): 0-based index of the net::IOBuffer |
| // in the function mock arg list. |
| // size_idx (template parameter 1): 0-based index of the byte count arg. |
| // str: pointer to the string which will receive data from the buffer. |
| ACTION_TEMPLATE(ReadBufferToString, |
| HAS_2_TEMPLATE_PARAMS(int, buf_idx, int, size_idx), |
| AND_1_VALUE_PARAMS(str)) { |
| str->assign(testing::get<buf_idx>(args)->data(), |
| testing::get<size_idx>(args)); |
| } |
| |
| // GMock action that writes data from a string to an IOBuffer. |
| // |
| // buf_idx (template parameter 0): 0-based index of the IOBuffer arg. |
| // str: the string containing data to be written to the IOBuffer. |
| ACTION_TEMPLATE(FillBufferFromString, |
| HAS_1_TEMPLATE_PARAMS(int, buf_idx), |
| AND_1_VALUE_PARAMS(str)) { |
| memcpy(testing::get<buf_idx>(args)->data(), str.data(), str.size()); |
| } |
| |
| // GMock action that enqueues a write completion callback in a queue. |
| // |
| // buf_idx (template parameter 0): 0-based index of the CompletionCallback. |
| // completion_queue: a pointer to the CompletionQueue. |
| ACTION_TEMPLATE(EnqueueCallback, |
| HAS_1_TEMPLATE_PARAMS(int, cb_idx), |
| AND_1_VALUE_PARAMS(completion_queue)) { |
| completion_queue->Push(testing::get<cb_idx>(args)); |
| } |
| |
| } // namespace |
| |
| class MockSocket : public cast_channel::CastTransportImpl::Channel { |
| public: |
| void Read(net::IOBuffer* buffer, |
| int bytes, |
| net::CompletionOnceCallback callback) override { |
| Read(buffer, bytes, base::AdaptCallbackForRepeating(std::move(callback))); |
| } |
| |
| void Write(net::IOBuffer* buffer, |
| int bytes, |
| net::CompletionOnceCallback callback) override { |
| Write(buffer, bytes, base::AdaptCallbackForRepeating(std::move(callback))); |
| } |
| |
| MOCK_METHOD3(Read, |
| void(net::IOBuffer* buf, |
| int buf_len, |
| const net::CompletionRepeatingCallback& callback)); |
| |
| MOCK_METHOD3(Write, |
| void(net::IOBuffer* buf, |
| int buf_len, |
| const net::CompletionRepeatingCallback& callback)); |
| }; |
| |
| class CastTransportTest : public testing::Test { |
| public: |
| CastTransportTest() : logger_(new Logger()) { |
| delegate_ = new MockCastTransportDelegate; |
| transport_.reset(new CastTransportImpl(&mock_socket_, kChannelId, |
| CreateIPEndPointForTest(), logger_)); |
| transport_->SetReadDelegate(base::WrapUnique(delegate_)); |
| } |
| ~CastTransportTest() override {} |
| |
| protected: |
| // Runs all pending tasks in the message loop. |
| void RunPendingTasks() { |
| base::RunLoop run_loop; |
| run_loop.RunUntilIdle(); |
| } |
| |
| base::test::SingleThreadTaskEnvironment task_environment_; |
| MockCastTransportDelegate* delegate_; |
| MockSocket mock_socket_; |
| Logger* logger_; |
| std::unique_ptr<CastTransport> transport_; |
| }; |
| |
| // ---------------------------------------------------------------------------- |
| // Asynchronous write tests |
| TEST_F(CastTransportTest, TestFullWriteAsync) { |
| CompletionQueue socket_cbs; |
| CompleteHandler write_handler; |
| std::string output; |
| |
| CastMessage message = CreateCastMessage(); |
| std::string serialized_message; |
| EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message)); |
| |
| EXPECT_CALL(mock_socket_, Write(NotNull(), serialized_message.size(), _)) |
| .WillOnce(DoAll(ReadBufferToString<0, 1>(&output), |
| EnqueueCallback<2>(&socket_cbs))); |
| EXPECT_CALL(write_handler, Complete(net::OK)); |
| transport_->SendMessage(message, |
| base::BindOnce(&CompleteHandler::Complete, |
| base::Unretained(&write_handler))); |
| RunPendingTasks(); |
| socket_cbs.Pop(serialized_message.size()); |
| RunPendingTasks(); |
| EXPECT_EQ(serialized_message, output); |
| } |
| |
| TEST_F(CastTransportTest, TestPartialWritesAsync) { |
| InSequence seq; |
| CompletionQueue socket_cbs; |
| CompleteHandler write_handler; |
| std::string output; |
| |
| CastMessage message = CreateCastMessage(); |
| std::string serialized_message; |
| EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message)); |
| |
| // Only one byte is written. |
| EXPECT_CALL(mock_socket_, |
| Write(NotNull(), static_cast<int>(serialized_message.size()), _)) |
| .WillOnce(DoAll(ReadBufferToString<0, 1>(&output), |
| EnqueueCallback<2>(&socket_cbs))); |
| // Remainder of bytes are written. |
| EXPECT_CALL( |
| mock_socket_, |
| Write(NotNull(), static_cast<int>(serialized_message.size() - 1), _)) |
| .WillOnce(DoAll(ReadBufferToString<0, 1>(&output), |
| EnqueueCallback<2>(&socket_cbs))); |
| |
| transport_->SendMessage(message, |
| base::BindOnce(&CompleteHandler::Complete, |
| base::Unretained(&write_handler))); |
| RunPendingTasks(); |
| EXPECT_EQ(serialized_message, output); |
| socket_cbs.Pop(1); |
| RunPendingTasks(); |
| |
| EXPECT_CALL(write_handler, Complete(net::OK)); |
| socket_cbs.Pop(serialized_message.size() - 1); |
| RunPendingTasks(); |
| EXPECT_EQ(serialized_message.substr(1, serialized_message.size() - 1), |
| output); |
| } |
| |
| TEST_F(CastTransportTest, TestWriteFailureAsync) { |
| CompletionQueue socket_cbs; |
| CompleteHandler write_handler; |
| CastMessage message = CreateCastMessage(); |
| EXPECT_CALL(mock_socket_, Write(NotNull(), _, _)) |
| .WillOnce(EnqueueCallback<2>(&socket_cbs)); |
| EXPECT_CALL(write_handler, Complete(net::ERR_FAILED)); |
| EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR)); |
| transport_->SendMessage(message, |
| base::BindOnce(&CompleteHandler::Complete, |
| base::Unretained(&write_handler))); |
| RunPendingTasks(); |
| socket_cbs.Pop(net::ERR_CONNECTION_RESET); |
| RunPendingTasks(); |
| EXPECT_EQ(ChannelEvent::SOCKET_WRITE, |
| logger_->GetLastError(kChannelId).channel_event); |
| EXPECT_EQ(net::ERR_CONNECTION_RESET, |
| logger_->GetLastError(kChannelId).net_return_value); |
| } |
| |
| // ---------------------------------------------------------------------------- |
| // Asynchronous read tests |
| TEST_F(CastTransportTest, TestFullReadAsync) { |
| InSequence s; |
| CompletionQueue socket_cbs; |
| |
| CastMessage message = CreateCastMessage(); |
| std::string serialized_message; |
| EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message)); |
| EXPECT_CALL(*delegate_, Start()); |
| |
| // Read bytes [0, 3]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), MessageFramer::MessageHeader::header_size(), _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message), |
| EnqueueCallback<2>(&socket_cbs))); |
| |
| // Read bytes [4, n]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), |
| serialized_message.size() - |
| MessageFramer::MessageHeader::header_size(), |
| _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr( |
| MessageFramer::MessageHeader::header_size(), |
| serialized_message.size() - |
| MessageFramer::MessageHeader::header_size())), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| |
| EXPECT_CALL(*delegate_, OnMessage(EqualsProto(message))); |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), MessageFramer::MessageHeader::header_size(), _)); |
| transport_->Start(); |
| RunPendingTasks(); |
| socket_cbs.Pop(MessageFramer::MessageHeader::header_size()); |
| socket_cbs.Pop(serialized_message.size() - |
| MessageFramer::MessageHeader::header_size()); |
| RunPendingTasks(); |
| } |
| |
| TEST_F(CastTransportTest, TestPartialReadAsync) { |
| InSequence s; |
| CompletionQueue socket_cbs; |
| |
| CastMessage message = CreateCastMessage(); |
| std::string serialized_message; |
| EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message)); |
| |
| EXPECT_CALL(*delegate_, Start()); |
| |
| // Read bytes [0, 3]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), MessageFramer::MessageHeader::header_size(), _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| // Read bytes [4, n-1]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), |
| serialized_message.size() - |
| MessageFramer::MessageHeader::header_size(), |
| _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr( |
| MessageFramer::MessageHeader::header_size(), |
| serialized_message.size() - |
| MessageFramer::MessageHeader::header_size() - 1)), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| // Read final byte. |
| EXPECT_CALL(mock_socket_, Read(NotNull(), 1, _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr( |
| serialized_message.size() - 1, 1)), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| EXPECT_CALL(*delegate_, OnMessage(EqualsProto(message))); |
| transport_->Start(); |
| socket_cbs.Pop(MessageFramer::MessageHeader::header_size()); |
| socket_cbs.Pop(serialized_message.size() - |
| MessageFramer::MessageHeader::header_size() - 1); |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), MessageFramer::MessageHeader::header_size(), _)); |
| socket_cbs.Pop(1); |
| } |
| |
| TEST_F(CastTransportTest, TestReadErrorInHeaderAsync) { |
| CompletionQueue socket_cbs; |
| |
| CastMessage message = CreateCastMessage(); |
| std::string serialized_message; |
| EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message)); |
| |
| EXPECT_CALL(*delegate_, Start()); |
| |
| // Read bytes [0, 3]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), MessageFramer::MessageHeader::header_size(), _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| |
| EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR)); |
| transport_->Start(); |
| // Header read failure. |
| socket_cbs.Pop(net::ERR_CONNECTION_RESET); |
| EXPECT_EQ(ChannelEvent::SOCKET_READ, |
| logger_->GetLastError(kChannelId).channel_event); |
| EXPECT_EQ(net::ERR_CONNECTION_RESET, |
| logger_->GetLastError(kChannelId).net_return_value); |
| } |
| |
| TEST_F(CastTransportTest, TestReadErrorInBodyAsync) { |
| CompletionQueue socket_cbs; |
| |
| CastMessage message = CreateCastMessage(); |
| std::string serialized_message; |
| EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message)); |
| |
| EXPECT_CALL(*delegate_, Start()); |
| |
| // Read bytes [0, 3]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), MessageFramer::MessageHeader::header_size(), _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| // Read bytes [4, n-1]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), |
| serialized_message.size() - |
| MessageFramer::MessageHeader::header_size(), |
| _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr( |
| MessageFramer::MessageHeader::header_size(), |
| serialized_message.size() - |
| MessageFramer::MessageHeader::header_size() - 1)), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR)); |
| transport_->Start(); |
| // Header read is OK. |
| socket_cbs.Pop(MessageFramer::MessageHeader::header_size()); |
| // Body read fails. |
| socket_cbs.Pop(net::ERR_CONNECTION_RESET); |
| EXPECT_EQ(ChannelEvent::SOCKET_READ, |
| logger_->GetLastError(kChannelId).channel_event); |
| EXPECT_EQ(net::ERR_CONNECTION_RESET, |
| logger_->GetLastError(kChannelId).net_return_value); |
| } |
| |
| TEST_F(CastTransportTest, TestReadCorruptedMessageAsync) { |
| CompletionQueue socket_cbs; |
| |
| CastMessage message = CreateCastMessage(); |
| std::string serialized_message; |
| EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message)); |
| |
| // Corrupt the serialized message body(set it to X's). |
| for (size_t i = MessageFramer::MessageHeader::header_size(); |
| i < serialized_message.size(); ++i) { |
| serialized_message[i] = 'x'; |
| } |
| |
| EXPECT_CALL(*delegate_, Start()); |
| // Read bytes [0, 3]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), MessageFramer::MessageHeader::header_size(), _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| // Read bytes [4, n]. |
| EXPECT_CALL(mock_socket_, |
| Read(NotNull(), |
| serialized_message.size() - |
| MessageFramer::MessageHeader::header_size(), |
| _)) |
| .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr( |
| MessageFramer::MessageHeader::header_size(), |
| serialized_message.size() - |
| MessageFramer::MessageHeader::header_size() - 1)), |
| EnqueueCallback<2>(&socket_cbs))) |
| .RetiresOnSaturation(); |
| |
| EXPECT_CALL(*delegate_, OnError(ChannelError::INVALID_MESSAGE)); |
| transport_->Start(); |
| socket_cbs.Pop(MessageFramer::MessageHeader::header_size()); |
| socket_cbs.Pop(serialized_message.size() - |
| MessageFramer::MessageHeader::header_size()); |
| } |
| |
| } // namespace cast_channel |