// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc.  All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd

// Author: kenton@google.com (Kenton Varda)
//  Based on original Protocol Buffers design by
//  Sanjay Ghemawat, Jeff Dean, and others.
//
// Since the reflection interface for DynamicMessage is implemented by
// GeneratedMessageReflection, the only thing we really have to test is
// that DynamicMessage correctly sets up the information that
// GeneratedMessageReflection needs to use.  So, we focus on that in this
// test.  Other tests, such as generated_message_reflection_unittest and
// reflection_ops_unittest, cover the rest of the functionality used by
// DynamicMessage.

#include "google/protobuf/dynamic_message.h"

#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "google/protobuf/descriptor.pb.h"
#include <gtest/gtest.h>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/cpp_features.pb.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/edition_unittest.pb.h"
#include "google/protobuf/generated_message_tctable_gen.h"
#include "google/protobuf/generated_message_tctable_impl.h"
#include "google/protobuf/port.h"
#include "google/protobuf/test_util.h"
#include "google/protobuf/unittest.pb.h"
#include "google/protobuf/unittest_import.pb.h"
#include "google/protobuf/unittest_import_public.pb.h"
#include "google/protobuf/unittest_no_field_presence.pb.h"


#include "google/protobuf/port_def.inc"

namespace google {
namespace protobuf {
namespace {

// Copies the FileDescriptorProtos behind several compiled-in test messages
// into `pool`, so DynamicMessage can be exercised on descriptors that are not
// backed by generated code.
//
// If `files` is non-null, the FileDescriptor built in `pool` for each copied
// file is appended to it, in the order the files are added below.
//
// A failed build is reported with ASSERT_NE inside the `add` lambda. That only
// returns from the lambda, so the remaining files are still attempted; callers
// that must stop on failure should wrap this call in ASSERT_NO_FATAL_FAILURE.
void AddUnittestDescriptors(
    DescriptorPool& pool, std::vector<const FileDescriptor*>* files = nullptr) {
  const auto add = [&](auto* descriptor) {
    FileDescriptorProto file;
    descriptor->file()->CopyTo(&file);
    // Keep the descriptor BuildFile() returns instead of looking the file up
    // again by name, and say which file failed if the build does not succeed.
    const FileDescriptor* built = pool.BuildFile(file);
    ASSERT_NE(built, nullptr) << "Failed to build " << file.name();
    if (files != nullptr) {
      files->push_back(built);
    }
  };
  // We want to make sure that DynamicMessage works (particularly with
  // extensions) even if we use descriptors that are *not* from compiled-in
  // types, so we make copies of the descriptors for unittest.proto and
  // unittest_import.proto. Order matters: BuildFile() fails when a file's
  // imports are not already in `pool`, so (presumably) each file here is
  // listed after the files it imports.
  add(unittest_import::PublicImportMessage::descriptor());
  add(unittest_import::ImportMessage::descriptor());
  add(unittest::TestAllTypes::descriptor());
  add(proto2_nofieldpresence_unittest::TestAllTypes::descriptor());

  add(google::protobuf::DescriptorProto::descriptor());
  add(pb::CppFeatures::descriptor());
  add(edition_unittest::TestAllTypes::descriptor());
}

// Two freshly created DynamicMessages must both read back the declared default
// of `default_string` ("hello"), and both views must point at the same buffer,
// i.e. creating a message does not copy the default value.
TEST(DynamicMessageTest,
     MicroStringFieldsWithDefaultValuesDontCopyTheDefaultOnCreate) {
  if (!internal::EnableExperimentalMicroString()) {
    GTEST_SKIP() << "MicroString is not enabled.";
  }
  DynamicMessageFactory factory;
  const Message* prototype =
      factory.GetPrototype(edition_unittest::TestAllTypes::descriptor());
  std::unique_ptr<Message> msg(prototype->New());
  std::unique_ptr<Message> msg2(prototype->New());

  const FieldDescriptor* field =
      edition_unittest::TestAllTypes::descriptor()->FindFieldByName(
          "default_string");
  // `field` is dereferenced by reflection below; fail cleanly if it is missing.
  ASSERT_NE(field, nullptr);

  // Both views are compared after both calls, so each gets its own scratch
  // space; the second call must not be able to reuse storage the first view
  // may point into.
  Reflection::ScratchSpace scratch;
  Reflection::ScratchSpace scratch2;
  const absl::string_view str =
      msg->GetReflection()->GetStringView(*msg, field, scratch);
  const absl::string_view str2 =
      msg2->GetReflection()->GetStringView(*msg2, field, scratch2);

  EXPECT_EQ(str, "hello");
  EXPECT_EQ(str2, "hello");
  // But also both point to the exact same buffer.
  EXPECT_EQ(static_cast<const void*>(str.data()),
            static_cast<const void*>(str2.data()));
}

// Overwriting a MicroString field that has a declared default and then
// clearing the message must restore the default value, served from the
// original default buffer.
TEST(DynamicMessageTest,
     MicroStringFieldsWithDefaultValuesResetProperlyOnClear) {
  if (!internal::EnableExperimentalMicroString()) {
    GTEST_SKIP() << "MicroString is not enabled.";
  }
  DynamicMessageFactory factory;
  const Message* prototype =
      factory.GetPrototype(edition_unittest::TestAllTypes::descriptor());
  std::unique_ptr<Message> msg(prototype->New());

  const FieldDescriptor* field =
      edition_unittest::TestAllTypes::descriptor()->FindFieldByName(
          "default_string");
  // `field` is dereferenced by reflection below; fail cleanly if it is missing.
  ASSERT_NE(field, nullptr);
  const Reflection* ref = msg->GetReflection();

  // `default_value` is kept alive across the later GetStringView() calls, so
  // it gets a scratch space of its own that nothing else reuses.
  Reflection::ScratchSpace default_scratch;
  const absl::string_view default_value =
      ref->GetStringView(*msg, field, default_scratch);

  Reflection::ScratchSpace scratch;
  ref->SetString(msg.get(), field, std::string("foo"));
  EXPECT_EQ("foo", ref->GetStringView(*msg, field, scratch));
  msg->Clear();
  const absl::string_view cleared = ref->GetStringView(*msg, field, scratch);
  EXPECT_EQ(default_value, cleared);
  // But also points to the original default buffer.
  EXPECT_EQ(static_cast<const void*>(cleared.data()),
            static_cast<const void*>(default_value.data()));
}

// Bundles a dynamically built message type with everything that keeps it
// alive: the pool holding its descriptor, the factory that built its
// prototype, and non-owning pointers to the prototype and its parse table.
//
// Member order matters: `factory` is constructed from `&pool`, so `pool` is
// declared (and therefore constructed) first and destroyed last.
struct OverflowTestCase {
  DescriptorPool pool;
  DynamicMessageFactory factory{&pool};
  // Prototype obtained from `factory`; not owned by this struct.
  const Message* prototype = nullptr;
  // Parse table of `prototype`; not owned by this struct (presumably valid for
  // as long as `prototype` is).
  const internal::TcParseTableBase* table = nullptr;
};

std::unique_ptr<OverflowTestCase> GenerateOverflowTestCase(int num_oneofs) {
  auto test_case = std::make_unique<OverflowTestCase>();

  FileDescriptorProto file_proto;
  file_proto.set_name("foo.proto");
  file_proto.set_edition(Edition::EDITION_2024);
  auto* desc_proto = file_proto.add_message_type();
  desc_proto->set_name("Foo");

  const auto add_field = [&](auto name, int number, auto type) {
    auto* field = desc_proto->add_field();
    field->set_name(name);
    field->set_number(number);
    field->set_type(type);
    return field;
  };

  // Overflow 16-bit integer worth of data.
  for (int i = 0; i < num_oneofs; ++i) {
    desc_proto->add_oneof_decl()->set_name(absl::StrCat("oneof_", i));
    add_field(absl::StrCat("p", i), 1'000'000 + i,
              FieldDescriptorProto::TYPE_STRING)
        ->set_oneof_index(i);
  }

  add_field("first_field", 1, FieldDescriptorProto::TYPE_INT64);
  add_field("second_field", 2, FieldDescriptorProto::TYPE_FIXED64);

  auto* file = test_case->pool.BuildFile(file_proto);
  ABSL_CHECK(file);
  auto* desc = file->message_type(0);
  test_case->prototype = test_case->factory.GetPrototype(desc);

  struct Robber : MessageLite {
    using MessageLite::GetTcParseTable;
  };

  test_case->table = (test_case->prototype->*&Robber::GetTcParseTable)();

  // Verify that we have the fields.
  {
    auto field = test_case->table->field_entries()[0];
    // It is the field we are looking for.
    ABSL_CHECK(absl::StrContains(internal::TypeCardToString(field.type_card),
                                 "kInt64"));
    // But the hasbit is small.
    ABSL_CHECK_LT(field.has_idx - test_case->table->has_bits_offset * 8,
                  internal::TailCallTableInfo::kMaxFastFieldHasbitIndex);
  }
  {
    auto field = test_case->table->field_entries()[1];
    // It is the field we are looking for.
    ABSL_CHECK(absl::StrContains(internal::TypeCardToString(field.type_card),
                                 "kFixed64"));
    // But the hasbit is small.
    ABSL_CHECK_LT(field.has_idx - test_case->table->has_bits_offset * 8,
                  internal::TailCallTableInfo::kMaxFastFieldHasbitIndex);
  }

  return test_case;
}

std::unique_ptr<OverflowTestCase> FindOverflowTestCase() {
  int low = 1, hi = std::numeric_limits<uint16_t>::max() / sizeof(uint32_t);

  while (true) {
    int mid = (low + hi) / 2;
    ABSL_CHECK_NE(mid, low) << "Bad initial bounds.";
    ABSL_CHECK_NE(mid, hi) << "Bad initial bounds.";
    auto test_case = GenerateOverflowTestCase(mid);
    ABSL_LOG(INFO) << "FindOverflowTestCase: low=" << low << " mid=" << mid
                   << " hi=" << hi << " first_offset="
                   << test_case->table->field_entries()[0].offset
                   << " second_offset="
                   << test_case->table->field_entries()[1].offset;
    if (test_case->table->field_entries()[0].offset >=
        std::numeric_limits<uint16_t>::max()) {
      // Too much padding.
      hi = mid;
    } else if (test_case->table->field_entries()[1].offset <
               std::numeric_limits<uint16_t>::max()) {
      // Too little padding.
      low = mid;
    } else {
      // Perfect padding
      return test_case;
    }
  }
}

// Uses a message whose `second_field` lives beyond a 16-bit offset. Checks that
// its fast-path table entry was replaced by MiniParse while `first_field` keeps
// a real fast-path entry. Then round-trips the message through serialization
// to confirm both fields parse correctly and no oneof gets set.
TEST(DynamicMessageTest, IncompatibleFastFieldsAreRejected) {
  auto test_case = FindOverflowTestCase();
  const internal::TcParseTableBase* table = test_case->table;
  const auto mini_parse =
      static_cast<internal::TailCallParseFunc>(&internal::TcParser::MiniParse);

  // The fast table should have 2 entries for the two good fields.
  ASSERT_EQ(table->fast_idx_mask, 8);
  // The one for field 1 (at pos 1) should not be MiniParse.
  EXPECT_NE(table->fast_entry(1)->target(), mini_parse);
  // While the one for field 2 (at pos 0) should have been switched to
  // MiniParse.
  EXPECT_EQ(table->fast_entry(0)->target(), mini_parse);

  // Now verify via parsing.
  const Descriptor* desc = test_case->prototype->GetDescriptor();
  const FieldDescriptor* first_field = desc->FindFieldByName("first_field");
  const FieldDescriptor* second_field = desc->FindFieldByName("second_field");
  ASSERT_NE(first_field, nullptr);
  ASSERT_NE(second_field, nullptr);

  // `first_field` is int64, so its value is typed to match GetInt64().
  constexpr int64_t value1 = 0x1234567890abcdef;
  constexpr uint64_t value2 = 0xfedcba0987654321u;
  const auto verify_message = [&](const Message& msg) {
    const Reflection* ref = msg.GetReflection();
    EXPECT_TRUE(ref->HasField(msg, first_field));
    EXPECT_EQ(ref->GetInt64(msg, first_field), value1);

    EXPECT_TRUE(ref->HasField(msg, second_field));
    EXPECT_EQ(ref->GetUInt64(msg, second_field), value2);

    for (int i = 0; i < desc->oneof_decl_count(); ++i) {
      EXPECT_FALSE(ref->HasOneof(msg, desc->oneof_decl(i)))
          << desc->oneof_decl(i)->name();
    }
  };

  std::unique_ptr<Message> msg(test_case->prototype->New());
  const Reflection* ref = msg->GetReflection();
  ref->SetInt64(msg.get(), first_field, value1);
  ref->SetUInt64(msg.get(), second_field, value2);
  verify_message(*msg);

  // Parse into a fresh message so every value observed afterwards must have
  // been produced by the parser.
  const std::string serialized = msg->SerializeAsString();
  std::unique_ptr<Message> parsed(test_case->prototype->New());
  ASSERT_TRUE(parsed->ParseFromString(serialized));

  verify_message(*parsed);
  EXPECT_EQ(serialized, parsed->SerializeAsString());
}

// Parameterized fixture. std::get<0> of the parameter selects arena vs. heap
// allocation (use_arena()). std::get<1> selects the edition_unittest vs.
// proto2_unittest variants of the test messages (use_editions_proto()).
// Every descriptor is looked up in `pool_`, which holds copies of the test
// protos, so all prototypes come from DynamicMessageFactory.
class DynamicMessageTest
    : public ::testing::TestWithParam<std::tuple<bool, bool>> {
 protected:
  DescriptorPool pool_;
  DynamicMessageFactory factory_;
  const Descriptor* descriptor_ = nullptr;
  const Message* prototype_ = nullptr;
  const Descriptor* extensions_descriptor_ = nullptr;
  const Message* extensions_prototype_ = nullptr;
  const Descriptor* packed_extensions_descriptor_ = nullptr;
  const Message* packed_extensions_prototype_ = nullptr;
  const Descriptor* packed_descriptor_ = nullptr;
  const Message* packed_prototype_ = nullptr;
  const Descriptor* oneof_descriptor_ = nullptr;
  const Message* oneof_prototype_ = nullptr;
  const Descriptor* proto3_descriptor_ = nullptr;
  const Message* proto3_prototype_ = nullptr;

  // `factory_` builds prototypes from descriptors in `pool_`.
  DynamicMessageTest() : factory_(&pool_) {}

  void SetUp() override {
    // AddUnittestDescriptors reports build failures with ASSERTs inside a
    // lambda, which do not stop SetUp by themselves; propagate them so the
    // test body is skipped when the pool could not be populated.
    ASSERT_NO_FATAL_FAILURE(AddUnittestDescriptors(pool_));

    const auto type_name = [&](absl::string_view name) {
      return absl::StrCat(
          use_editions_proto() ? "edition_unittest." : "proto2_unittest.",
          name);
    };

    descriptor_ = pool_.FindMessageTypeByName(type_name("TestAllTypes"));
    ASSERT_NE(descriptor_, nullptr);
    prototype_ = factory_.GetPrototype(descriptor_);

    extensions_descriptor_ =
        pool_.FindMessageTypeByName(type_name("TestAllExtensions"));
    ASSERT_NE(extensions_descriptor_, nullptr);
    extensions_prototype_ = factory_.GetPrototype(extensions_descriptor_);

    packed_extensions_descriptor_ =
        pool_.FindMessageTypeByName(type_name("TestPackedExtensions"));
    ASSERT_NE(packed_extensions_descriptor_, nullptr);
    packed_extensions_prototype_ =
        factory_.GetPrototype(packed_extensions_descriptor_);

    packed_descriptor_ =
        pool_.FindMessageTypeByName(type_name("TestPackedTypes"));
    ASSERT_NE(packed_descriptor_, nullptr);
    packed_prototype_ = factory_.GetPrototype(packed_descriptor_);

    oneof_descriptor_ = pool_.FindMessageTypeByName(type_name("TestOneof2"));
    ASSERT_NE(oneof_descriptor_, nullptr);
    oneof_prototype_ = factory_.GetPrototype(oneof_descriptor_);

    // The no-field-presence messages exist only in one variant, so this
    // lookup does not depend on use_editions_proto().
    proto3_descriptor_ = pool_.FindMessageTypeByName(
        "proto2_nofieldpresence_unittest.TestAllTypes");
    ASSERT_NE(proto3_descriptor_, nullptr);
    proto3_prototype_ = factory_.GetPrototype(proto3_descriptor_);
  }

  bool use_arena() const { return std::get<0>(GetParam()); }
  bool use_editions_proto() const { return std::get<1>(GetParam()); }
};

TEST_P(DynamicMessageTest, Descriptor) {
  // The prototype returned by GetPrototype() must report exactly the
  // descriptor it was requested for.
  const Descriptor* reported = prototype_->GetDescriptor();
  EXPECT_EQ(reported, descriptor_);
}

TEST_P(DynamicMessageTest, OnePrototype) {
  // A second request for the same descriptor must yield the very same
  // prototype object, not a new one.
  const Message* requested_again = factory_.GetPrototype(descriptor_);
  EXPECT_EQ(prototype_, requested_again);
}

TEST_P(DynamicMessageTest, Defaults) {
  // Every field of the untouched prototype must read back as cleared, i.e.
  // all default values were wired up correctly.
  const Message& untouched = *prototype_;
  TestUtil::ReflectionTester tester(descriptor_);
  tester.ExpectClearViaReflection(untouched);
}

TEST_P(DynamicMessageTest, IndependentOffsets) {
  // Check that all fields have independent offsets by setting each
  // one to a unique value then checking that they all still have those
  // unique values (i.e. they don't stomp each other).
  Arena arena;
  Message* message = prototype_->New(use_arena() ? &arena : nullptr);
  // Heap-allocated messages are owned here; arena-allocated ones are released
  // together with `arena`, so the owner stays empty in that case.
  std::unique_ptr<Message> heap_owner(use_arena() ? nullptr : message);
  TestUtil::ReflectionTester reflection_tester(descriptor_);

  reflection_tester.SetAllFieldsViaReflection(message);
  reflection_tester.ExpectAllFieldsSetViaReflection(*message);
}

TEST_P(DynamicMessageTest, Extensions) {
  // Check that extensions work.
  Arena arena;
  Message* message = extensions_prototype_->New(use_arena() ? &arena : nullptr);
  // Heap-allocated messages are owned here; arena-allocated ones are released
  // together with `arena`, so the owner stays empty in that case.
  std::unique_ptr<Message> heap_owner(use_arena() ? nullptr : message);
  TestUtil::ReflectionTester reflection_tester(extensions_descriptor_);

  reflection_tester.SetAllFieldsViaReflection(message);
  reflection_tester.ExpectAllFieldsSetViaReflection(*message);
}

TEST_P(DynamicMessageTest, PackedExtensions) {
  // Check that packed extensions work.
  Arena arena;
  Message* message =
      packed_extensions_prototype_->New(use_arena() ? &arena : nullptr);
  // Heap-allocated messages are owned here; arena-allocated ones are released
  // together with `arena`, so the owner stays empty in that case.
  std::unique_ptr<Message> heap_owner(use_arena() ? nullptr : message);
  TestUtil::ReflectionTester reflection_tester(packed_extensions_descriptor_);

  reflection_tester.SetPackedFieldsViaReflection(message);
  reflection_tester.ExpectPackedFieldsSetViaReflection(*message);
}

TEST_P(DynamicMessageTest, PackedFields) {
  // Check that packed fields work properly.
  Arena arena;
  Message* message = packed_prototype_->New(use_arena() ? &arena : nullptr);
  // Heap-allocated messages are owned here; arena-allocated ones are released
  // together with `arena`, so the owner stays empty in that case.
  std::unique_ptr<Message> heap_owner(use_arena() ? nullptr : message);
  TestUtil::ReflectionTester reflection_tester(packed_descriptor_);

  reflection_tester.SetPackedFieldsViaReflection(message);
  reflection_tester.ExpectPackedFieldsSetViaReflection(*message);
}

TEST_P(DynamicMessageTest, Oneof) {
  // Check that oneof fields work properly.
  Arena arena;
  Message* message = oneof_prototype_->New(use_arena() ? &arena : nullptr);
  // Heap-allocated messages are owned here; arena-allocated ones are released
  // together with `arena`, so the owner stays empty in that case.
  std::unique_ptr<Message> heap_owner(use_arena() ? nullptr : message);

  // Check default values.
  const Descriptor* descriptor = message->GetDescriptor();
  const Reflection* reflection = message->GetReflection();
  const auto field = [&](absl::string_view name) {
    return descriptor->FindFieldByName(name);
  };

  EXPECT_EQ(0, reflection->GetInt32(*message, field("foo_int")));
  EXPECT_EQ("", reflection->GetString(*message, field("foo_string")));
  EXPECT_EQ("", reflection->GetString(*message, field("foo_cord")));
  EXPECT_EQ("", reflection->GetString(*message, field("foo_string_piece")));
  EXPECT_EQ("", reflection->GetString(*message, field("foo_bytes")));
  EXPECT_EQ(use_editions_proto() ? +edition_unittest::TestOneof2::UNKNOWN
                                 : +unittest::TestOneof2::FOO,
            reflection->GetEnum(*message, field("foo_enum"))->number());

  // Unset message-typed oneof members must read back as the factory's
  // prototype for their type.
  const Message* nested_prototype = factory_.GetPrototype(
      oneof_descriptor_->FindNestedTypeByName("NestedMessage"));
  EXPECT_EQ(nested_prototype,
            &reflection->GetMessage(*message, field("foo_message")));
  const Message* foogroup_prototype = factory_.GetPrototype(
      oneof_descriptor_->FindNestedTypeByName("FooGroup"));
  EXPECT_EQ(foogroup_prototype,
            &reflection->GetMessage(*message, field("foogroup")));
  // NOTE(review): foo_lazy_message is presumably not a FooGroup, which would
  // make this inequality trivially true; confirm whether a comparison against
  // nested_prototype was intended.
  EXPECT_NE(foogroup_prototype,
            &reflection->GetMessage(*message, field("foo_lazy_message")));

  EXPECT_EQ(5, reflection->GetInt32(*message, field("bar_int")));
  EXPECT_EQ("STRING", reflection->GetString(*message, field("bar_string")));
  EXPECT_EQ("CORD", reflection->GetString(*message, field("bar_cord")));
  EXPECT_EQ("SPIECE",
            reflection->GetString(*message, field("bar_string_piece")));
  EXPECT_EQ("BYTES", reflection->GetString(*message, field("bar_bytes")));
  EXPECT_EQ(unittest::TestOneof2::BAR,
            reflection->GetEnum(*message, field("bar_enum"))->number());

  // Check set functions.
  TestUtil::ReflectionTester reflection_tester(oneof_descriptor_);
  reflection_tester.SetOneofViaReflection(message);
  reflection_tester.ExpectOneofSetViaReflection(*message);
}

TEST_P(DynamicMessageTest, SpaceUsed) {
  // Test that SpaceUsedLong() works properly

  // Since we share the implementation with generated messages, we don't need
  // to test very much here.  Just make sure it appears to be working.

  Arena arena;
  Message* message = prototype_->New(use_arena() ? &arena : nullptr);
  // Heap-allocated messages are owned here; arena-allocated ones are released
  // together with `arena`, so the owner stays empty in that case.
  std::unique_ptr<Message> heap_owner(use_arena() ? nullptr : message);
  TestUtil::ReflectionTester reflection_tester(descriptor_);

  const size_t initial_space_used = message->SpaceUsedLong();

  reflection_tester.SetAllFieldsViaReflection(message);
  EXPECT_LT(initial_space_used, message->SpaceUsedLong());
}

TEST_P(DynamicMessageTest, Arena) {
  // Allocate one message of each dynamic type on `arena` and deliberately
  // never free any of them: the arena owns them, so nothing should leak.
  Arena arena;
  [[maybe_unused]] Message* const message = prototype_->New(&arena);
  [[maybe_unused]] Message* const extension_message =
      extensions_prototype_->New(&arena);
  [[maybe_unused]] Message* const packed_message =
      packed_prototype_->New(&arena);
  [[maybe_unused]] Message* const oneof_message =
      oneof_prototype_->New(&arena);
}


TEST_P(DynamicMessageTest, Proto3) {
  std::unique_ptr<Message> message(proto3_prototype_->New());
  const Reflection* refl = message->GetReflection();
  const Descriptor* desc = message->GetDescriptor();

  // Just test a single primitive and single message field here to make sure we
  // are getting the no-field-presence semantics elsewhere. DynamicMessage uses
  // GeneratedMessageReflection under the hood, so the rest should be fine as
  // long as GMR recognizes that we're using a proto3 message.
  const FieldDescriptor* optional_int32 =
      desc->FindFieldByName("optional_int32");
  const FieldDescriptor* optional_msg =
      desc->FindFieldByName("optional_nested_message");
  // Both are dereferenced by reflection below, so a missing field must stop
  // the test rather than crash it.
  ASSERT_NE(optional_int32, nullptr);
  ASSERT_NE(optional_msg, nullptr);

  EXPECT_FALSE(refl->HasField(*message, optional_int32));
  refl->SetInt32(message.get(), optional_int32, 42);
  EXPECT_TRUE(refl->HasField(*message, optional_int32));
  refl->SetInt32(message.get(), optional_int32, 0);
  EXPECT_FALSE(refl->HasField(*message, optional_int32));

  EXPECT_FALSE(refl->HasField(*message, optional_msg));
  refl->MutableMessage(message.get(), optional_msg);
  EXPECT_TRUE(refl->HasField(*message, optional_msg));
  delete refl->ReleaseMessage(message.get(), optional_msg);
  EXPECT_FALSE(refl->HasField(*message, optional_msg));

  // Also ensure that the default instance handles field presence properly.
  EXPECT_FALSE(refl->HasField(*proto3_prototype_, optional_msg));
}

// Runs every DynamicMessageTest case over all four (use_arena,
// use_editions_proto) combinations, in the order (false, false),
// (false, true), (true, false), (true, true).
INSTANTIATE_TEST_SUITE_P(
    UseArena, DynamicMessageTest,
    ::testing::Combine(::testing::Values(false, true),
                       ::testing::Values(false, true)));


}  // namespace
}  // namespace protobuf
}  // namespace google

#include "google/protobuf/port_undef.inc"
