using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

using MongoDB.Bson;
using MongoDB.Bson.IO;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver;
using MongoDB.Driver.Linq;

namespace TestStringIdSerializer
{
/// <summary>
/// Sample document type: an integer Id plus a string S.
/// </summary>
public class Foo
{
    /// <summary>Integer identifier of the instance.</summary>
    public int Id { get; set; }

    /// <summary>String payload.</summary>
    public string S { get; set; }

    /// <summary>
    /// Formats the instance as <c>Foo:Id=&lt;Id&gt;,S="&lt;S&gt;"</c>; a null S is rendered as an empty string.
    /// </summary>
    public override string ToString()
    {
        const string layout = "Foo:Id={0},S=\"{1}\"";
        return String.Format(layout, this.Id, this.S);
    }
}

/// <summary>
/// Serializes a string member as the Int32 id assigned to it by a <see cref="StringTable"/>,
/// and deserializes a stored Int32 id back into its string through the same table.
/// A null string is written as BSON null and read back as null.
/// </summary>
public class StringIdSerializer : BsonBaseSerializer
{
    private readonly StringTable _stringTable;

    /// <summary>Creates a serializer backed by the given string table.</summary>
    /// <param name="stringTable">Table used to translate between strings and ids.</param>
    /// <exception cref="ArgumentNullException"><paramref name="stringTable"/> is null.</exception>
    public StringIdSerializer(StringTable stringTable)
    {
        if (stringTable == null)
        {
            throw new ArgumentNullException("stringTable");
        }
        _stringTable = stringTable;
    }

    /// <summary>
    /// Reads either BSON null (returns null) or an Int32 id, which is resolved to its string
    /// via <see cref="StringTable.LookupString"/> (that method throws if the id is unknown).
    /// </summary>
    public override object Deserialize(BsonReader bsonReader, Type nominalType, Type actualType, IBsonSerializationOptions options)
    {
        // Null strings are stored as BSON null by Serialize, so they must be read back the same way.
        if (bsonReader.GetCurrentBsonType() == BsonType.Null)
        {
            bsonReader.ReadNull();
            return null;
        }
        var id = bsonReader.ReadInt32();
        return _stringTable.LookupString(id);
    }

    /// <summary>
    /// Writes the string's id from <see cref="StringTable.LookupId"/> as an Int32
    /// (registering the string if it is new), or BSON null when <paramref name="value"/> is null.
    /// </summary>
    public override void Serialize(BsonWriter bsonWriter, Type nominalType, object value, IBsonSerializationOptions options)
    {
        if (value == null)
        {
            // A null string cannot be a dictionary key in the table; store null instead of an id.
            bsonWriter.WriteNull();
            return;
        }
        var id = _stringTable.LookupId((string)value);
        bsonWriter.WriteInt32(id);
    }
}

/// <summary>
/// In-memory two-way map between strings and sequential int ids, starting at 1.
/// The mapping lives only in this instance's dictionaries and is not persisted.
/// No synchronization is performed.
/// </summary>
public class StringTable
{
    // Last id handed out; the first registered string receives id 1.
    private int _nextId;
    private readonly Dictionary<string, int> _ids = new Dictionary<string, int>();
    private readonly Dictionary<int, string> _strings = new Dictionary<int, string>();

    /// <summary>
    /// Returns the id for <paramref name="s"/>, assigning and recording the next id if the string is new.
    /// </summary>
    /// <param name="s">String to look up or register.</param>
    /// <returns>The string's id (1-based).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="s"/> is null.</exception>
    public int LookupId(string s)
    {
        if (s == null)
        {
            throw new ArgumentNullException("s");
        }
        int id;
        if (!_ids.TryGetValue(s, out id))
        {
            id = ++_nextId;
            _ids.Add(s, id);
            _strings.Add(id, s);
        }
        return id;
    }

    /// <summary>Returns the string previously registered under <paramref name="id"/>.</summary>
    /// <param name="id">An id returned earlier by <see cref="LookupId"/>.</param>
    /// <returns>The registered string.</returns>
    /// <exception cref="ArgumentOutOfRangeException">No string is registered under <paramref name="id"/>.</exception>
    public string LookupString(int id)
    {
        string s;
        if (!_strings.TryGetValue(id, out s))
        {
            var message = string.Format("Id {0} not found in string table.", id);
            // ArgumentOutOfRangeException takes the parameter name FIRST (unlike ArgumentException).
            throw new ArgumentOutOfRangeException("id", id, message);
        }
        return s;
    }
}

/// <summary>
/// Console demo that stores Foo.S through a StringIdSerializer, then does three things:
/// writes two Foo documents to the "test" collection of a MongoDB server on localhost,
/// prints them back as raw BSON and as Foo instances, and prints them via a LINQ query on S.
/// </summary>
public static class Program
{
    /// <summary>Runs the demo, reports any exception to the console, then waits for Enter.</summary>
    /// <param name="_args">Command-line arguments; not used.</param>
    public static void Main(string[] _args)
    {
        try
        {
            RunDemo();
        }
        catch (Exception ex)
        {
            // Report rather than crash, so the pause below still happens.
            Console.WriteLine("Unhandled exception:");
            Console.WriteLine(ex);
        }

        Console.WriteLine("Press Enter to continue");
        Console.ReadLine();
    }

    // The demo steps proper; any exception propagates to Main's handler.
    private static void RunDemo()
    {
        var table = new StringTable();
        var idSerializer = new StringIdSerializer(table);
        BsonClassMap.RegisterClassMap<Foo>(map =>
        {
            map.AutoMap();
            // Persist Foo.S as the table's int id rather than as the string itself.
            map.GetMemberMap(f => f.S).SetSerializer(idSerializer);
        });

        var server = MongoServer.Create("mongodb://localhost/?safe=true");
        var database = server.GetDatabase("test");
        var collection = database.GetCollection<Foo>("test");

        // Start from an empty collection so only this run's documents are printed.
        if (collection.Exists())
        {
            collection.Drop();
        }
        collection.Insert(new Foo { Id = 1, S = "abc" });
        collection.Insert(new Foo { Id = 2, S = "def" });

        Console.WriteLine("Raw BSON documents inserted:");
        foreach (var raw in collection.FindAllAs<BsonDocument>())
        {
            Console.WriteLine(raw.ToJson());
        }
        Console.WriteLine();

        Console.WriteLine("Foo instances read back:");
        WriteFoos(collection.FindAll());

        Console.WriteLine("Foo instances read back using LINQ query for \"def\":");
        var defQuery = collection.AsQueryable().Where(f => f.S == "def");
        var mongoQuery = ((MongoQueryable<Foo>)defQuery).GetMongoQuery();
        Console.WriteLine("MongoDB query: {0}", mongoQuery.ToJson());
        WriteFoos(defQuery);
    }

    // Prints each Foo on its own line, then one blank line.
    private static void WriteFoos(IEnumerable<Foo> foos)
    {
        foreach (var foo in foos)
        {
            Console.WriteLine(foo.ToString());
        }
        Console.WriteLine();
    }
}
}