/// <summary>
/// Iteratively eliminates unit productions: each round, every unit production that has a
/// non-NaN probability is replaced by the productions returned from
/// <c>UnitReplacementProductions</c> and has its weight set to zero. Rounds repeat until the
/// summed probability of the unit productions reaches zero.
/// </summary>
/// <param name="productions">
/// The productions to process. They are stored as keys of a working table built with
/// <c>ProductionComparer</c>. The <c>Weight</c> of unit productions is overwritten with zero
/// on the stored objects themselves, so instances passed in here can be modified.
/// </param>
/// <returns>
/// A lazily evaluated sequence of the working table's productions whose weight is strictly
/// greater than zero.
/// </returns>
/// <exception cref="InvalidOperationException">
/// Thrown when the summed unit probability of a round is larger than that of the previous round.
/// </exception>
private IEnumerable<Production> RemoveUnits(IEnumerable<Production> productions) {
    var productionTable = new Dictionary<Production, Production>(new ProductionComparer());
    foreach (var production in productions) {
        // NOTE(review): for productions the comparer treats as equal, the indexer keeps the
        // first instance as the key and only replaces the value; the code below reads keys only.
        productionTable[production] = production;
    }

    // Start above any possible sum so the first round always passes the invariant check.
    var oldSumOfProbability = double.MaxValue;

    // we keep looping until the probability of any unit production has been driven to zero
    // as an invariant, we make sure that the sum of the unit probabilities goes down each iteration
    while (oldSumOfProbability > 0) {
        // TODO: don't need to build this table every round
        var productionsByNonterminal = GrammarHelpers.BuildLookupTable(productionTable.Keys);
        var newSumOfProbability = 0.0;
        var toAdd = new List<Production>();
        var toRemove = new List<Production>();

        // find all the unit productions and replace them with equivalent rules
        // X -> Y gets replaced with rules X -> Z for all Y -> Z
        // Changes are collected first and applied after the loop so the table is not
        // modified while its keys are being enumerated.
        foreach (var production in productionTable.Keys) {
            if (!production.IsUnit()) {
                continue;
            }

            var thisProb = GetProbability(production, productionsByNonterminal);
            if (double.IsNaN(thisProb)) {
                // A NaN probability is skipped entirely: not counted, not replaced, not zeroed.
                continue;
            }

            newSumOfProbability += thisProb;
            var replacements = UnitReplacementProductions(production, productionsByNonterminal);
            toAdd.AddRange(replacements);
            toRemove.Add(production);
        }

        // NOTE(review): only an increase is rejected; a sum that stays at the same non-zero
        // value passes this check and the loop continues.
        if (oldSumOfProbability < newSumOfProbability) {
            throw new InvalidOperationException("Invariant didn't hold, we want probability sums to decrease every iteration");
        }
        oldSumOfProbability = newSumOfProbability;

        // Unit productions are not removed from the table; they are zero-weighted so the
        // final filter drops them.
        foreach (var production in toRemove) {
            production.Weight = 0.0;
        }
        MergeProductions(productionTable, toAdd);
    }

    return productionTable.Keys.Where((p) => p.Weight > 0.0);
}