// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "hwy/bit_set.h"

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include <algorithm>  // std::find
#include <map>
#include <utility>  // std::make_pair
#include <vector>

#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h"
#include "hwy/tests/test_util.h"

namespace hwy {
namespace {

// Exercises the basic life cycle of a `Set`: a default-constructed set is
// empty, setting the highest valid index makes exactly that bit visible through
// every query, and clearing it returns the set to the empty state.
template <class Set>
void SmokeTest() {
  // Highest index the set can hold.
  constexpr size_t kLast = Set().MaxSize() - 1;

  Set bits;

  // Shared check: no bit is observable through any of the query functions.
  const auto expect_empty = [&]() {
    HWY_ASSERT(!bits.Any());
    HWY_ASSERT(!bits.All());
    HWY_ASSERT(!bits.Get(0));
    HWY_ASSERT(!bits.Get(kLast));
    HWY_ASSERT(bits.First0() == 0);
    bits.Foreach(
        [](size_t i) { HWY_ABORT("Set should be empty but got %zu\n", i); });
    HWY_ASSERT(bits.Count() == 0);
  };

  // A freshly constructed set holds nothing.
  expect_empty();

  // Setting the last index makes it (and only it) visible.
  bits.Set(kLast);
  HWY_ASSERT(bits.Any());
  HWY_ASSERT(!bits.All());
  HWY_ASSERT(bits.Get(kLast));
  HWY_ASSERT(bits.Count() == 1);
  HWY_ASSERT(bits.First() == kLast);
  HWY_ASSERT(bits.First0() == 0);
  bits.Foreach([](size_t i) { HWY_ASSERT(i == kLast); });

  // Clearing restores the empty state. Index 0 was never set, so clearing it
  // must not change anything either.
  bits.Clear(kLast);
  bits.Clear(0);
  expect_empty();
}

// Register one smoke test per set type: BitSet64, BitSet<320>,
// AtomicBitSet<400> and BitSet4096 with its default template argument.
TEST(BitSetTest, SmokeTestSet64) { SmokeTest<BitSet64>(); }
TEST(BitSetTest, SmokeTestSet) { SmokeTest<BitSet<320>>(); }
TEST(BitSetTest, SmokeTestAtomicSet) { SmokeTest<AtomicBitSet<400>>(); }
TEST(BitSetTest, SmokeTestSet4096) { SmokeTest<BitSet4096<>>(); }

// Verifies that `SetNonzeroBitsFrom64` sets the indices whose bits are 1 in
// its 64-bit argument, and that a second call adds to the bits that were
// already set instead of replacing them.
template <class Set>
void TestSetNonzeroBitsFrom64() {
  constexpr size_t kLowest = 0;
  Set bits;

  // First call: a single bit at the lowest index.
  bits.SetNonzeroBitsFrom64(1ULL << kLowest);
  HWY_ASSERT(bits.Get(kLowest));
  HWY_ASSERT(bits.Any());
  HWY_ASSERT(!bits.All());
  HWY_ASSERT(bits.Count() == 1);
  HWY_ASSERT(bits.First() == kLowest);
  HWY_ASSERT(bits.First0() == kLowest + 1);
  bits.Foreach([](size_t i) { HWY_ASSERT(i == kLowest); });

  // Second call: bits 4, 5 and 6 (mask 0x70).
  bits.SetNonzeroBitsFrom64((1ULL << 4) | (1ULL << 5) | (1ULL << 6));
  HWY_ASSERT(bits.Get(kLowest));
  for (size_t i = 4; i <= 6; ++i) {
    HWY_ASSERT(bits.Get(i));
  }
  HWY_ASSERT(bits.Any());
  HWY_ASSERT(!bits.All());
  HWY_ASSERT(bits.Count() == 4);
  // The previously set lowest bit survives, so it is still the first one and
  // the first zero is still the index right after it.
  HWY_ASSERT(bits.First() == kLowest);
  HWY_ASSERT(bits.First0() == kLowest + 1);
  bits.Foreach(
      [](size_t i) { HWY_ASSERT(i == kLowest || (i >= 4 && i <= 6)); });
}

// Register the SetNonzeroBitsFrom64 test for the two set types it is
// instantiated with here: BitSet64 and the default BitSet4096.
TEST(BitSetTest, TestSetNonzeroBits64) { TestSetNonzeroBitsFrom64<BitSet64>(); }
TEST(BitSetTest, TestSetNonzeroBits4096) {
  TestSetNonzeroBitsFrom64<BitSet4096<>>();
}

// Straightforward reference set used to cross-check the bit sets under test.
// Members are stored twice: in an ordered map from member to its position in a
// vector (ordered so the smallest member is cheap to find, and sparse-friendly
// for `BitSet4096`), and in that vector so a random member can be picked by
// index. Every operation asserts that the two containers stay in sync.
class SlowSet {
 public:
  // Adds `i`. Has no effect if `i` is already a member.
  void Set(size_t i) {
    // A new member would be appended, so its position is the current size.
    const auto result = pos_of_.emplace(i, items_.size());
    const auto& entry = *result.first;
    if (!result.second) {
      // Already a member: the existing entry must point at `i` in the vector.
      HWY_ASSERT(entry.first == i);
      const size_t pos = entry.second;
      HWY_ASSERT(items_[pos] == i);
    } else {
      items_.push_back(i);
      HWY_ASSERT(pos_of_.size() == items_.size());
    }
    HWY_ASSERT(Get(i));
  }

  // Returns whether `i` is a member, verifying that map and vector agree.
  bool Get(size_t i) const {
    const auto found = pos_of_.find(i);
    if (found != pos_of_.end()) {
      HWY_ASSERT(items_[found->second] == i);
      return true;
    }
    // Not in the map, so it must not be in the vector either.
    HWY_ASSERT(std::find(items_.begin(), items_.end(), i) == items_.end());
    return false;
  }

  // Removes `i`. Has no effect if `i` is not a member.
  void Clear(size_t i) {
    if (!Get(i)) return;

    const auto found = pos_of_.find(i);
    const size_t hole = found->second;
    pos_of_.erase(found);

    // Shrink the vector by one; unless the removed tail element was `i`
    // itself, it is relocated into the position `i` used to occupy.
    const size_t tail = items_.back();
    items_.pop_back();
    if (tail != i) {
      HWY_ASSERT(items_[hole] == i);
      items_[hole] = tail;
      pos_of_[tail] = hole;
      HWY_ASSERT(Get(tail));  // the relocated member is still reachable
    } else {
      HWY_ASSERT(hole == items_.size());  // `i` was the tail element
    }
    HWY_ASSERT(!Get(i));
  }

  // Returns the number of members.
  size_t Count() const {
    const size_t num = items_.size();
    HWY_ASSERT(pos_of_.size() == num);
    return num;
  }

  // Returns a member chosen via `rng`. Must not call if Count() == 0.
  size_t RandomChoice(RandomState& rng) const {
    HWY_ASSERT(Count() != 0);
    const size_t random = static_cast<size_t>(hwy::Random32(&rng));
    return items_[random % items_.size()];
  }

  // Asserts that `set` has exactly the same members as this reference, and
  // that its aggregate queries (Any/All/Count/First/First0) are consistent.
  template <class Set>
  void CheckSame(const Set& set) {
    const size_t set_count = set.Count();
    HWY_ASSERT(set.Any() == (set_count != 0));
    HWY_ASSERT(set.All() == (set_count == set.MaxSize()));
    HWY_ASSERT(Count() == set_count);

    // Membership must agree in both directions.
    set.Foreach([this](size_t i) { HWY_ASSERT(Get(i)); });
    for (const size_t member : items_) {
      HWY_ASSERT(set.Get(member));
    }

    // The map is ordered by key, so its first entry is the smallest member.
    if (set.Any()) {
      const size_t smallest = pos_of_.begin()->first;
      HWY_ASSERT(set.First() == smallest);
    }

    // If there is room left, First0 must name a valid index held by neither.
    if (set.All()) return;
    const size_t first_unset = set.First0();
    HWY_ASSERT(first_unset < set.MaxSize());
    HWY_ASSERT(!set.Get(first_unset));
    HWY_ASSERT(!Get(first_unset));
  }

 private:
  std::vector<size_t> items_;         // members, in no particular order
  std::map<size_t, size_t> pos_of_;   // member -> its index in `items_`
};

template <class Set>
void TestSetWithGrowProb(uint64_t grow_prob) {
  constexpr uint32_t max_size = static_cast<uint32_t>(Set().MaxSize());
  RandomState rng;

  // Multiple independent random tests:
  for (size_t rep = 0; rep < AdjustedReps(100); ++rep) {
    Set set;
    SlowSet slow_set;
    // Mutate sets via random walk and ensure they are the same afterwards.
    for (size_t iter = 0; iter < AdjustedReps(1000); ++iter) {
      const uint64_t bits = (Random64(&rng) >> 10) & 0x3FF;
      if (bits > 980 && slow_set.Count() != 0) {
        // Small chance of reinsertion: already present, unchanged after.
        const size_t i = slow_set.RandomChoice(rng);
        const size_t count = set.Count();
        HWY_ASSERT(set.Get(i));
        slow_set.Set(i);
        set.Set(i);
        HWY_ASSERT(set.Get(i));
        HWY_ASSERT(count == set.Count());
      } else if (bits < grow_prob) {
        // Set random value; no harm if already set.
        const size_t i = static_cast<size_t>(Random32(&rng) % max_size);
        slow_set.Set(i);
        set.Set(i);
        HWY_ASSERT(set.Get(i));
      } else if (slow_set.Count() != 0) {
        // Remove existing item.
        const size_t i = slow_set.RandomChoice(rng);
        const size_t count = set.Count();
        HWY_ASSERT(set.Get(i));
        slow_set.Clear(i);
        set.Clear(i);
        HWY_ASSERT(!set.Get(i));
        HWY_ASSERT(count == set.Count() + 1);
      }
    }
    slow_set.CheckSame(set);
  }
}

// Runs the randomized test for `Set` with two growth probabilities (out of
// 1024), in this order.
template <class Set>
void TestSetRandom() {
  // The first entry grows less often, so the set is often nearly empty; the
  // second lets it fill up more.
  const uint64_t kGrowProbs[] = {400, 600};
  for (const uint64_t grow_prob : kGrowProbs) {
    TestSetWithGrowProb<Set>(grow_prob);
  }
}

// Register the randomized comparison against SlowSet for a range of set types
// and sizes.
TEST(BitSetTest, TestSet64) { TestSetRandom<BitSet64>(); }
TEST(BitSetTest, TestSet41) { TestSetRandom<BitSet<41>>(); }
TEST(BitSetTest, TestSet) { TestSetRandom<BitSet<199>>(); }
// One partial u64
TEST(BitSetTest, TestAtomicSet32) { TestSetRandom<AtomicBitSet<32>>(); }
// 3 whole u64
TEST(BitSetTest, TestAtomicSet192) { TestSetRandom<AtomicBitSet<192>>(); }
// BitSet4096 with an explicit size of 3000, then with its default size.
TEST(BitSetTest, TestSet3000) { TestSetRandom<BitSet4096<3000>>(); }
TEST(BitSetTest, TestSet4096) { TestSetRandom<BitSet4096<>>(); }

}  // namespace
}  // namespace hwy

// Invoked at global scope, outside namespace hwy. NOTE(review): presumably
// supplies the test binary's main() and comes from hwy/tests/hwy_gtest.h —
// confirm against that header.
HWY_TEST_MAIN();
