天天看點

記一次被yield return坑的曆程。

事情的經過是這樣的:

我用C#寫了一個很簡單的一個通過疊代生成序列的函數。

public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length)
{
    Checker.NullCheck(nameof(f), f);    
    Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);

    var current = initVal;
    while (--length >= 0)
    {
        yield return (current = f(current));
    }
}      

其中NullCheck用于檢查參數是否為null,如果是則抛出ArgumentNullException異常。

對應的,我寫了如下單元測試代碼去檢測這個異常。

public void TestIterate()
{
    Func<int, int> f = null;
    Assert.Throws<ArgumentNullException>(() => f.Iterate(1, 7));
    
    // Other tests
}      

但是,這個測試出乎意料的fail了。

一開始,我以為是NullCheck函數的問題,可我把NullCheck直接換成了if語句,還是通不過。

後來我在Iterate函數下斷點并調試。結果調試器根本沒有停在斷點上,直接運作完了測試。

我以為是我測試的方法不對,是以我不斷的修改測試代碼,甚至還一度以為是.NET的Unit Tests出了bug。

最終,我在這個測試代碼發現了問題:

Assert.Throws<ArgumentNullException>(() =>
{
    var seq = f.Iterate(1, 7);
    foreach (int ele in seq)
        Console.WriteLine(ele);
});      

當我調試這個測試時,程式停在了我之前在Iterate函數上下的斷點。

于是,我在 var seq = f.Iterate(1, 7); 上下斷點,并逐漸運作。這時我發現,當程式運作到 var seq = f.Iterate(1, 7); 時并不會進入Iterate函數;而是當程式運作到foreach語句後才進入。

這就要涉及到yield return的具體工作流程。當函數代碼中出現yield return,調用這個函數會直接傳回一個IEnumerable<T>或IEnumerator<T>對象,并不會執行函數體的任何代碼。這些代碼都被封裝到了傳回對象的内部。它們會在你開始枚舉的時候開始執行。

是以,上面兩個Check并不會在函數調用時執行,而是在當你開始foreach的時候才執行。

這并不是我想要的結果。我希望在調用函數時就檢查參數合法性,如果不合法便直接抛出異常。

解決這個問題有兩種途徑,一是把它拆成兩個函數:

public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length)
{
    Checker.NullCheck(nameof(f), f);    
    Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
            
    return IterateWithoutCheck(f, initVal, length);
}

private static IEnumerable<T> IterateWithoutCheck<T>(this Func<T, T> f, T initVal, int length)
{
    var current = initVal;
    while (--length >= 0)
    {
        yield return (current = f(current));
    }
}      

或者,你也可以将這個函數包裝成一個類。

class FunctionIterator<T> : IEnumerable<T>
{
    private readonly Func<T, T> f;
    private readonly T initVal;
    private readonly int length;
        
    public FunctionIterator(Func<T, T> f, T initVal, int length)
    {
        Checker.NullCheck(nameof(f), f);
        Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
            
        this.f = f;
        this.initVal = initVal;
        this.length = length;
    }

    public IEnumerator<T> GetEnumerator()
    {
        T current = initVal;

        for (int i = 0; i < length; ++i)
            yield return (current = f(current));
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}