事情的經過是這樣的:
我用C#寫了一個很簡單的一個通過疊代生成序列的函數。
public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length)
{
Checker.NullCheck(nameof(f), f);
Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
var current = initVal;
while (--length >= 0)
{
yield return (current = f(current));
}
}
其中NullCheck用于檢查參數是否為null,如果是則抛出ArgumentNullException異常。
對應的,我寫了如下單元測試代碼去檢測這個異常。
public void TestIterate()
{
Func<int, int> f = null;
Assert.Throws<ArgumentNullException>(() => f.Iterate(1, 7));
// Other tests
}
但是,這個測試出乎意料的fail了。
一開始,我以為是NullCheck函數的問題,可我把NullCheck直接換成了if語句,還是通不過。
後來我在Iterate函數下斷點并調試。結果調試器根本沒有停在斷點上,直接運作完了測試。
我以為是我測試的方法不對,是以我不斷的修改測試代碼,甚至還一度以為是.NET的Unit Tests出了bug。
最終,我在這個測試代碼發現了問題:
Assert.Throws<ArgumentNullException>(() =>
{
var seq = f.Iterate(1, 7);
foreach (int ele in seq)
Console.WriteLine(ele);
});
當我調試這個測試時,程式停在了我之前在Iterate函數上下的斷點。
于是,我在 var seq = f.Iterate(1, 7); 上下斷點,并逐漸運作。這時我發現,當程式運作到 var seq = f.Iterate(1, 7); 時并不會進入Iterate函數;而是當程式運作到foreach語句後才進入。
這就要涉及到yield return的具體工作流程。當函數代碼中出現yield return,調用這個函數會直接傳回一個IEnumerable<T>或IEnumerator<T>對象,并不會執行函數體的任何代碼。這些代碼都被封裝到了傳回對象的内部。它們會在你開始枚舉的時候開始執行。
是以,上面兩個Check并不會在函數調用時執行,而是在當你開始foreach的時候才執行。
這并不是我想要的結果。我希望在調用函數時就檢查參數合法性,如果不合法便直接抛出異常。
解決這個問題有兩種途徑,一是把它拆成兩個函數:
public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length)
{
Checker.NullCheck(nameof(f), f);
Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
return IterateWithoutCheck(f, initVal, length);
}
private static IEnumerable<T> IterateWithoutCheck<T>(this Func<T, T> f, T initVal, int length)
{
var current = initVal;
while (--length >= 0)
{
yield return (current = f(current));
}
}
或者,你也可以将這個函數包裝成一個類。
class FunctionIterator<T> : IEnumerable<T>
{
private readonly Func<T, T> f;
private readonly T initVal;
private readonly int length;
public FunctionIterator(Func<T, T> f, T initVal, int length)
{
Checker.NullCheck(nameof(f), f);
Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
this.f = f;
this.initVal = initVal;
this.length = length;
}
public IEnumerator<T> GetEnumerator()
{
T current = initVal;
for (int i = 0; i < length; ++i)
yield return (current = f(current));
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}