用动态规划实现文本对齐

时间:2013-08-13 04:17:45

标签: algorithm python-3.x dynamic-programming

我试图通过麻省理工学院开放式课程here的课程来理解动态规划的概念。关于OCW视频的解释很棒,但我觉得在我将解释实现到代码之前我并不理解它。在实施时,我会参考讲义here中的一些注释,特别是注释的第3页。

问题是,我不知道如何将一些数学符号转换为代码。这是我实施的解决方案的一部分(并认为它是正确实现的):

import math

paragraph = "Some long lorem ipsum text."
words = paragraph.split(" ")

# Count total length for all strings in a list of strings.
# This function will be used by the badness function below.
def total_length(str_arr):
    total = 0

    for string in str_arr:
        total = total + len(string)

    total = total + len(str_arr) # spaces
    return total

# Calculate the badness score for a word.
# str_arr is assumed be send as word[i:j] as in the notes
# we don't make i and j as argument since it will require
# global vars then.
def badness(str_arr, page_width):
    line_len = total_length(str_arr)
    if line_len > page_width:
        return float('nan') 
    else:
        return math.pow(page_width - line_len, 3)

现在我不理解的部分是在讲义中的第3点到第5点。我实际上不明白,也不知道从哪里开始实现这些。到目前为止,我已经尝试迭代单词列表,并计算每个所谓的行尾的坏处,如下所示:

def justifier(str_arr, page_width):
    paragraph = str_arr
    par_len = len(paragraph)
    result = [] # stores each line as list of strings
    for i in range(0, par_len):
        if i == (par_len - 1):
            result.append(paragraph)
        else:
            dag = [badness(paragraph[i:j], page_width) + justifier(paragraph[j:], page_width) for j in range(i + 1, par_len + 1)] 
            # Should I do a min(dag), get the index, and declares it as end of line?

但是,我不知道如何继续这项功能,说实话,我不理解这一行:

dag = [badness(paragraph[i:j], page_width) + justifier(paragraph[j:], page_width) for j in range(i + 1, par_len + 1)] 

以及如何将justifier作为int返回(因为我已经决定将返回值存储在result中,这是一个列表。我应该创建另一个函数并从中回复那里?应该有任何递归吗?

你能告诉我下一步该怎么做,并解释一下这是怎么做的动态编程吗?我真的看不出递归的位置,以及子问题是什么。

先谢谢。

5 个答案:

答案 0 :(得分:19)

如果您无法理解动态编程本身的核心思想,请参考以下内容:

动态编程本质上牺牲了空间复杂度 时间复杂度(但是你使用的额外空间通常非常与你的时间相比很少保存,使动态编程完全值得,如果正确实施)。您可以随时存储每个递归调用的值(例如,在数组或字典中),这样当您在递归树的另一个分支中遇到相同的递归调用时,可以避免第二次计算。

不,你必须使用递归。以下是我正在使用循环的问题的实现。我非常密切地关注了由AlexSilva链接的TextAlignment.pdf。希望你觉得这很有帮助。

def length(wordLengths, i, j): return sum(wordLengths[i- 1:j]) + j - i + 1 def breakLine(text, L): # wl = lengths of words wl = [len(word) for word in text.split()] # n = number of words in the text n = len(wl) # total badness of a text l1 ... li m = dict() # initialization m[0] = 0 # auxiliary array s = dict() # the actual algorithm for i in range(1, n + 1): sums = dict() k = i while (length(wl, k, i) <= L and k > 0): sums[(L - length(wl, k, i))**3 + m[k - 1]] = k k -= 1 m[i] = min(sums) s[i] = sums[min(sums)] # actually do the splitting by working backwords line = 1 while n > 1: print("line " + str(line) + ": " + str(s[n]) + "->" + str(n)) n = s[n] - 1 line += 1

答案 1 :(得分:4)

对于仍然对此感兴趣的任何人:关键是从文本末尾向后移动(如上所述here)。 如果你这样做,你只需比较已经记忆的元素。

说,words是要根据textwidth包装的字符串列表。然后,在讲座的符号中,任务减少到三行代码:

import numpy as np

textwidth = 80

DP = [0]*(len(words)+1)

for i in range(len(words)-1,-1,-1):
    DP[i] = np.min([DP[j] + badness(words[i:j],textwidth) for j in range(i+1,len(words)+1)])

使用:

def badness(line,textwidth):

    # Number of gaps
    length_line = len(line) - 1

    for word in line:
        length_line += len(word)

    if length_line > textwidth: return float('inf')

    return ( textwidth - length_line )**3

他提到可以添加第二个列表以跟踪违规位置。您可以通过将代码更改为:

来实现
DP = [0]*(len(words)+1)
breaks = [0]*(len(words)+1)

for i in range(len(words)-1,-1,-1):
    temp = [DP[j] + badness(words[i:j],args.textwidth) for j in range(i+1,len(words)+1)]

    index = np.argmin(temp)

    # Index plus position in upper list
    breaks[i] = index + i + 1
    DP[i] = temp[index]

要恢复文本,只需使用中断位置列表:

def reconstruct_text(words,breaks):                                                                                                                

    lines = []
    linebreaks = []

    i = 0 
    while True:

        linebreaks.append(breaks[i])
        i = breaks[i]

        if i == len(words):
            linebreaks.append(0)
            break

    for i in range( len(linebreaks) ):
        lines.append( ' '.join( words[ linebreaks[i-1] : linebreaks[i] ] ).strip() )

    return lines

结果:(text = reconstruct_text(words,breaks)

Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy
eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam
voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet
clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit
amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam
nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed
diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet
clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.

有人可能会想要添加一些空格。这非常棘手(因为人们可能会提出各种美学规则)但是天真的尝试可能是:

import re

def spacing(text,textwidth,maxspace=4):

    for i in range(len(text)):

        length_line = len(text[i])

        if length_line < textwidth:

            status_length = length_line
            whitespaces_remain = textwidth - status_length
            Nwhitespaces = text[i].count(' ')

            # If whitespaces (to add) per whitespace exeeds
            # maxspace, don't do anything.
            if whitespaces_remain/Nwhitespaces > maxspace-1:pass
            else:
                text[i] = text[i].replace(' ',' '*( 1 + int(whitespaces_remain/Nwhitespaces)) )
                status_length = len(text[i])

                # Periods have highest priority for whitespace insertion
                periods = text[i].split('.')

                # Can we add a whitespace behind each period?
                if len(periods) - 1 + status_length <= textwidth:
                    text[i] = '. '.join(periods).strip()

                status_length = len(text[i])
                whitespaces_remain = textwidth - status_length
                Nwords = len(text[i].split())
                Ngaps = Nwords - 1

                if whitespaces_remain != 0:factor = Ngaps / whitespaces_remain

                # List of whitespaces in line i
                gaps = re.findall('\s+', text[i])

                temp = text[i].split()
                for k in range(Ngaps):
                    temp[k] = ''.join([temp[k],gaps[k]])

                for j in range(whitespaces_remain):
                    if status_length >= textwidth:pass
                    else:
                        replace = temp[int(factor*j)]
                        replace = ''.join([replace, " "])
                        temp[int(factor*j)] = replace

                text[i] = ''.join(temp)

    return text

是什么让你:(text = spacing(text,textwidth)

Lorem  ipsum  dolor  sit  amet, consetetur  sadipscing  elitr,  sed  diam nonumy
eirmod  tempor  invidunt  ut labore  et  dolore  magna aliquyam  erat,  sed diam
voluptua.   At  vero eos  et accusam  et justo  duo dolores  et ea  rebum.  Stet
clita  kasd  gubergren,  no  sea  takimata sanctus  est  Lorem  ipsum  dolor sit
amet.   Lorem  ipsum  dolor  sit amet,  consetetur  sadipscing  elitr,  sed diam
nonumy  eirmod  tempor invidunt  ut labore  et dolore  magna aliquyam  erat, sed
diam  voluptua.  At vero eos et accusam et  justo duo dolores et ea rebum.  Stet
clita  kasd gubergren, no sea  takimata sanctus est Lorem  ipsum dolor sit amet.

答案 2 :(得分:1)

我刚看到讲座,思想就会放在这里,无论我能理解什么。我已经以与提问者类似的格式输入代码。正如讲座所解释的那样,我在这里使用了递归。
第3点,定义了重复。这基本上是一个底线,在这里你计算一个先前与较高输入有关的函数的值,然后用它来计算低值输入。
讲座解释如下:
DP(i)= min(DP(j)+不良(i,j))
对于j,其从i + 1到n变化。
在这里,我从n变化到0(从下到上!) 当DP(n)= 0时,
DP(n-1)= DP(n)+不良(n-1,n)
然后从D(n-1)和D(n)计算D(n-2)并从中取出最小值。
通过这种方式,你可以直到i = 0,这是不好的最终答案!
在第4点,你可以看到,这里有两个循环。一个用于我,另一个在我用于j。
因此,当i = 0时,j(max)= n,i = 1,j(max)= n-1,... i = n,j(max)= 0。 因此总时间=这些的加法= n(n + 1)/ 2。
因此O(n ^ 2)。
点#5只是确定DP [0]的解决方案!
希望这可以帮助!

import math

justification_map = {}
min_map = {}

def total_length(str_arr):
    total = 0

    for string in str_arr:
        total = total + len(string)

    total = total + len(str_arr) - 1 # spaces
    return total

def badness(str_arr, page_width):
    line_len = total_length(str_arr)
    if line_len > page_width:
        return float('nan') 
    else:
        return math.pow(page_width - line_len, 3)

def justify(i, n, words, page_width):
    if i == n:

        return 0
    ans = []
    for j in range(i+1, n+1):
        #ans.append(justify(j, n, words, page_width)+ badness(words[i:j], page_width))
        ans.append(justification_map[j]+ badness(words[i:j], page_width))
    min_map[i] = ans.index(min(ans)) + 1
    return min(ans)

def main():
    print "Enter page width"
    page_width = input()
    print "Enter text"
    paragraph = input() 
    words = paragraph.split(' ')
    n = len(words)
    #justification_map[n] = 0 
    for i in reversed(range(n+1)):
        justification_map[i] = justify(i, n, words, page_width)

    print "Minimum badness achieved: ", justification_map[0]

    key = 0
    while(key <n):
        key = key + min_map[key]
        print key

if __name__ == '__main__':
    main()

答案 3 :(得分:0)

根据您的定义,这就是我的想法。

import math

class Text(object):
    def __init__(self, words, width):
        self.words = words
        self.page_width = width
        self.str_arr = words
        self.memo = {}

    def total_length(self, str):
        total = 0
        for string in str:
            total = total + len(string)
        total = total + len(str) # spaces
        return total

    def badness(self, str):
        line_len = self.total_length(str)
        if line_len > self.page_width:
            return float('nan') 
        else:
            return math.pow(self.page_width - line_len, 3)

    def dp(self):
        n = len(self.str_arr)
        self.memo[n-1] = 0

        return self.judge(0)

    def judge(self, i):
        if i in self.memo:
            return self.memo[i]

        self.memo[i] = float('inf') 
        for j in range(i+1, len(self.str_arr)):
            bad = self.judge(j) + self.badness(self.str_arr[i:j])
            if bad < self.memo[i]:
                self.memo[i] = bad

        return self.memo[i]

答案 4 :(得分:0)

Java实现 给定最大线宽为L,证明文本T合理的想法是考虑文本的所有后缀(为了精确地形成后缀,请考虑使用单词而不是字符)。 动态编程不过是“谨慎的蛮力”。 如果您考虑采用蛮力方法,则需要执行以下操作。

  1. 考虑在第一行中放入1,2,.. n个单词。
  2. 对于在情况1中描述的每种情况(假设i个单词放在第1行中),请考虑在第2行中放置1,2,.. n -i个单词,然后在第3行中剩余单词的情况,依此类推。 。

相反,我们只考虑问题,以找出将单词放在行首的成本。 通常,我们可以将DP(i)定义为将第(i-1)个单词视为行的开头的代价。

如何为DP(i)形成递归关系?

如果第j个单词是下一行的开头,则当前行将包含单词[i:j)(j排他),而第j个单词作为下一行的开头的开销将为DP(j)。 因此,DP(i)= DP(j)+在当前行中放置单词[i:j)的成本 由于我们希望使总成本最小化,因此可以定义DP(i)。

重复关系:

  

DP(i)=最小{DP(j)+在当前行中放置单词[i:j的成本}   对于[i + 1,n]中的所有j

注意j = n表示下一行没有剩余的单词。

基本情况:DP(n)= 0 =>此时,没有可写的字了。

总结:

  1. 子问题:后缀,单词[:i]
  2. 猜:从下一行开始,选择项n-i-> O(n)
  3. 重复发生:DP(i)=最小值{DP(j)+在当前行中放置单词[i:j)的成本} 如果我们使用备忘录,则花括号内的表达式应该花费O(1)时间,并且循环运行O(n)次(选择次数#)。 i从n变化到0 =>因此总复杂度降低到O(n ^ 2)。

现在,即使我们得出证明文本合理的最低成本,我们也需要通过跟踪上面表达式中选择为最小值的j值来解决原始问题,以便以后可以使用相同的值进行打印找出合理的文字。这个想法是保留父指针。

希望这可以帮助您了解解决方案。以下是上述想法的简单实现。

 public class TextJustify {
    class IntPair {
        //The cost or badness
        final int x;

        //The index of word at the beginning of a line
        final int y;
        IntPair(int x, int y) {this.x=x;this.y=y;}
    }
    public List<String> fullJustify(String[] words, int L) {
        IntPair[] memo = new IntPair[words.length + 1];

        //Base case
        memo[words.length] = new IntPair(0, 0);


        for(int i = words.length - 1; i >= 0; i--) {
            int score = Integer.MAX_VALUE;
            int nextLineIndex = i + 1;
            for(int j = i + 1; j <= words.length; j++) {
                int badness = calcBadness(words, i, j, L);
                if(badness < 0 || badness == Integer.MAX_VALUE) break;
                int currScore = badness + memo[j].x;
                if(currScore < 0 || currScore == Integer.MAX_VALUE) break;
                if(score > currScore) {
                    score = currScore;
                    nextLineIndex = j;
                }
            }
            memo[i] = new IntPair(score, nextLineIndex);
        }

        List<String> result = new ArrayList<>();
        int i = 0;
        while(i < words.length) {
            String line = getLine(words, i, memo[i].y);
            result.add(line);
            i = memo[i].y;
        }
        return result;
    }

    private int calcBadness(String[] words, int start, int end, int width) {
        int length = 0;
        for(int i = start; i < end; i++) {
            length += words[i].length();
            if(length > width) return Integer.MAX_VALUE;
            length++;
        }
        length--;
        int temp = width - length;
        return temp * temp;
    }


    private String getLine(String[] words, int start, int end) {
        StringBuilder sb = new StringBuilder();
        for(int i = start; i < end - 1; i++) {
            sb.append(words[i] + " ");
        }
        sb.append(words[end - 1]);

        return sb.toString();
    }
  }