組み合わせの等価性を比較するIEqualityComparerのGetHashCode

例えば以下のようなクラスがあったとして

class Pair
{
    public int Value1 { get; }
    public int Value2 { get; }

    public Pair(int value1, int value2)
    {
        Value1 = value1;
        Value2 = value2;
    }
}


Value1とValue2の組み合わせが等しいかを比較したくて次のようなIEqualityComparerを実装したとします。

class PairEqualityComparer : IEqualityComparer<Pair>
{
    public bool Equals(Pair x, Pair y)
    {
        return
            (x.Value1.Equals(y.Value1) && x.Value2.Equals(y.Value2)) ||
            (x.Value1.Equals(y.Value2) && x.Value2.Equals(y.Value1));
    }
}


このとき、GetHashCodeの実装はどうしたらいいのでしょうか。


もちろん、以下のコードはこれは意図した挙動になりません。

class PairEqualityComparer : IEqualityComparer<Pair>
{
    public bool Equals(Pair x, Pair y)
    {
        return
            (x.Value1.Equals(y.Value1) && x.Value2.Equals(y.Value2)) ||
            (x.Value1.Equals(y.Value2) && x.Value2.Equals(y.Value1));
    }

    public int GetHashCode(Pair obj)
    {
        obj.GetHashCode();
    }
}



Equalsが正しく実装されていれば、以下のようなコードを書いても想定通りに動きます。

class PairEqualityComparer : IEqualityComparer<Pair>
{
    public bool Equals(Pair x, Pair y)
    {
        return
            (x.Value1.Equals(y.Value1) && x.Value2.Equals(y.Value2)) ||
            (x.Value1.Equals(y.Value2) && x.Value2.Equals(y.Value1));
    }

    public int GetHashCode(Pair obj)
    {
        return 0;
    }
}

ただこれは線形サーチになるのでかなり遅いです。



私はいつも次のような実装をしています。(正しいかどうかは分かりません)

class PairEqualityComparer : IEqualityComparer<Pair>
{
    public bool Equals(Pair x, Pair y)
    {
        return
            (x.Value1.Equals(y.Value1) && x.Value2.Equals(y.Value2)) ||
            (x.Value1.Equals(y.Value2) && x.Value2.Equals(y.Value1));
    }

    public int GetHashCode(Pair obj)
    {
        var val1 = Math.Min(obj.Value1, obj.Value2);
        var val2 = Math.Max(obj.Value1, obj.Value2);

        var hash = 17;
        hash = hash * 23 + val1.GetHashCode();
        hash = hash * 23 + val2.GetHashCode();

        return hash;

    }
}

ソートしてからハッシュを計算します。
独自クラスの場合はIComparerを実装してソートします。



一応検証コードを載せときます。

class Pair
{
    public int Value1 { get; }
    public int Value2 { get; }

    public Pair(int value1, int value2)
    {
        Value1 = value1;
        Value2 = value2;
    }
}

class PairEqualityComparer : IEqualityComparer<Pair>
{
    public bool Equals(Pair x, Pair y)
    {
        return
            (x.Value1.Equals(y.Value1) && x.Value2.Equals(y.Value2)) ||
            (x.Value1.Equals(y.Value2) && x.Value2.Equals(y.Value1));
    }

    public int GetHashCode(Pair obj)
    {
        return obj.GetHashCode();
    }
}

class Program
{
    static void Main(string[] args)
    {
        var pairs = new HashSet<Pair>(new PairEqualityComparer());

        var sw = new Stopwatch();
        sw.Start();

        var count = 0;
        while (count < 100000)
        {
            pairs.Add(new Pair(count, 100000 - count));
            count++;
        }

        sw.Stop();
        Console.WriteLine(sw.Elapsed);
        Console.WriteLine(pairs.Count());
        Console.ReadLine();
    }
}

結果

00:00:00.0231028
100000


class Pair
{
    public int Value1 { get; }
    public int Value2 { get; }

    public Pair(int value1, int value2)
    {
        Value1 = value1;
        Value2 = value2;
    }
}

class PairEqualityComparer : IEqualityComparer<Pair>
{
    public bool Equals(Pair x, Pair y)
    {
        return
            (x.Value1.Equals(y.Value1) && x.Value2.Equals(y.Value2)) ||
            (x.Value1.Equals(y.Value2) && x.Value2.Equals(y.Value1));
    }

    public int GetHashCode(Pair obj)
    {
        return 0;
    }
}

class Program
{
    static void Main(string[] args)
    {
        var pairs = new HashSet<Pair>(new PairEqualityComparer());

        var sw = new Stopwatch();
        sw.Start();

        var count = 0;
        while (count < 100000)
        {
            pairs.Add(new Pair(count, 100000 - count));
            count++;
        }

        sw.Stop();
        Console.WriteLine(sw.Elapsed);
        Console.WriteLine(pairs.Count());
        Console.ReadLine();
    }
}

結果

00:01:25.4922908
50001


class Pair
{
    public int Value1 { get; }
    public int Value2 { get; }

    public Pair(int value1, int value2)
    {
        Value1 = value1;
        Value2 = value2;
    }
}

class PairEqualityComparer : IEqualityComparer<Pair>
{
    public bool Equals(Pair x, Pair y)
    {
        return
            (x.Value1.Equals(y.Value1) && x.Value2.Equals(y.Value2)) ||
            (x.Value1.Equals(y.Value2) && x.Value2.Equals(y.Value1));
    }

    public int GetHashCode(Pair obj)
    {
        var val1 = Math.Min(obj.Value1, obj.Value2);
        var val2 = Math.Max(obj.Value1, obj.Value2);

        var hash = 17;
        hash = hash * 23 + val1.GetHashCode();
        hash = hash * 23 + val2.GetHashCode();

        return hash;
    }
}

class Program
{
    static void Main(string[] args)
    {
        var pairs = new HashSet<Pair>(new PairEqualityComparer());

        var sw = new Stopwatch();
        sw.Start();

        var count = 0;
        while (count < 100000)
        {
            pairs.Add(new Pair(count, 100000 - count));
            count++;
        }

        sw.Stop();
        Console.WriteLine(sw.Elapsed);
        Console.WriteLine(pairs.Count());
        Console.ReadLine();
    }
}

結果

00:00:00.0221448
50001