LoginSignup
0
1

More than 3 years have passed since last update.

Linq の ~OrDefault 拡張

Posted at

これはなに

Linq の FirstOrDefault 系メソッドに、見つからなければ返される default 値を指定できるようにした拡張メソッド。

モチベーション

列挙体 eHoge があるときに、

return new eHoge[]{...}.FirstOrDefault(_ => {...}) ?? OnNotFound();

で、コンパイルエラーになる。列挙体は値型なので、あたりまえっちゃ当たり前なんだけど、そうすると

var r = new eHoge[]{...}.FirstOrDefault(_ => {...});
return r != default ? r : OnNotFound();

とか核ハメになって、それよりも

return new eHoge[]{...}.FirstOrDefault(_ => {...}, OnNotFound());

のように書けたほうが便利じゃねってなった。

コード

using System;
using System.Collections.Generic;
using System.Text;

namespace System.Linq
{
    public static partial class Enumerable
    {
        public static TSource FirstOrDefault<TSource>(this IEnumerable<TSource> source, TSource defaultValue)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            //Error.ArgumentNull("source");
            IList<TSource> list = source as IList<TSource>;
            if (list != null)
            {
                if (list.Count > 0) return list[0];
            }
            else
            {
                using (IEnumerator<TSource> e = source.GetEnumerator())
                {
                    if (e.MoveNext()) return e.Current;
                }
            }
            return defaultValue;
        }

        public static TSource FirstOrDefault<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate, TSource defaultValue)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            if (predicate == null) throw new ArgumentNullException(nameof(predicate));
            //if (source == null) throw Error.ArgumentNull("source");
            //if (predicate == null) throw Error.ArgumentNull("predicate");
            foreach (TSource element in source)
            {
                if (predicate(element)) return element;
            }
            return defaultValue;
        }

        public static TSource LastOrDefault<TSource>(this IEnumerable<TSource> source, TSource defaultValue)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            //if (source == null) throw Error.ArgumentNull("source");
            IList<TSource> list = source as IList<TSource>;
            if (list != null)
            {
                int count = list.Count;
                if (count > 0) return list[count - 1];
            }
            else
            {
                using (IEnumerator<TSource> e = source.GetEnumerator())
                {
                    if (e.MoveNext())
                    {
                        TSource result;
                        do
                        {
                            result = e.Current;
                        } while (e.MoveNext());
                        return result;
                    }
                }
            }
            return defaultValue;
        }

        public static TSource LastOrDefault<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate, TSource defaultValue)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            if (predicate == null) throw new ArgumentNullException(nameof(predicate));
            //if (source == null) throw Error.ArgumentNull("source");
            //if (predicate == null) throw Error.ArgumentNull("predicate");
            TSource result = defaultValue;
            foreach (TSource element in source)
            {
                if (predicate(element))
                {
                    result = element;
                }
            }
            return result;
        }


        public static TSource SingleOrDefault<TSource>(this IEnumerable<TSource> source, TSource defaultValue)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            //if (source == null) throw Error.ArgumentNull("source");
            IList<TSource> list = source as IList<TSource>;
            if (list != null)
            {
                switch (list.Count)
                {
                    case 0: return defaultValue;
                    case 1: return list[0];
                }
            }
            else
            {
                using (IEnumerator<TSource> e = source.GetEnumerator())
                {
                    if (!e.MoveNext()) return defaultValue;
                    TSource result = e.Current;
                    if (!e.MoveNext()) return result;
                }
            }
            throw new InvalidOperationException("MoreThanOneElement");
            // throw Error.MoreThanOneElement();
        }

        public static TSource SingleOrDefault<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate, TSource defaultValue)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            if (predicate == null) throw new ArgumentNullException(nameof(predicate));
            //if (source == null) throw Error.ArgumentNull("source");
            //if (predicate == null) throw Error.ArgumentNull("predicate");
            TSource result = defaultValue;
            long count = 0;
            foreach (TSource element in source)
            {
                if (predicate(element))
                {
                    result = element;
                    checked { count++; }
                }
            }
            switch (count)
            {
                case 0: return defaultValue;
                case 1: return result;
            }
            throw new InvalidOperationException("MoreThanOneMatch");
            // throw Error.MoreThanOneMatch();
        }

        public static TSource ElementAtOrDefault<TSource>(this IEnumerable<TSource> source, int index, TSource defaultValue)
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            //if (source == null) throw Error.ArgumentNull("source");
            if (index >= 0)
            {
                IList<TSource> list = source as IList<TSource>;
                if (list != null)
                {
                    if (index < list.Count) return list[index];
                }
                else
                {
                    using (IEnumerator<TSource> e = source.GetEnumerator())
                    {
                        while (true)
                        {
                            if (!e.MoveNext()) break;
                            if (index == 0) return e.Current;
                            index--;
                        }
                    }
                }
            }
            return defaultValue;
        }
    }
}

github のソースコードから ~OrDefault 系のメソッドを取ってきて、default(T) ってなっているところを defaultValue に変更しただけ。

正確に言えば defaultValue 引数の型を TSource にしてしまっていることで、上記モチベにおける OnNotFound() の評価順序が Source からの検索後から検索前に移ってしまっている。
それが問題になる場合は defaultValue を受け取る代わりに Func<TSource> を受け取って defaultValue を返す代わりに受け取った Func<TSource> を呼び出す関数を追加。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1