临时定义

在依赖类型编程中,有时类型检查器需要的类型会和我们编写的程序的类型不同 (在这里是它们的形式不同),然而尽管如此,它们在证明上依然是等价的。 例如,回想一下 parity 函数:

data Parity : Nat -> Type where
   Even : Parity (n + n)
   Odd  : Parity (S (n + n))

我们要把它实现为:

parity : (n:Nat) -> Parity n
parity Z     = Even {n=Z}
parity (S Z) = Odd {n=Z}
parity (S (S k)) with (parity k)
  parity (S (S (j + j)))     | Even = Even {n=S j}
  parity (S (S (S (j + j)))) | Odd  = Odd {n=S j}

它简单地指定了零为偶数,一为奇数,然后递归地说明了 k+2 的奇偶性与 k 相同。 显式地标出 n 是奇数还是偶数对于类型推断来说是必须的。然而,类型检查器却拒绝了它:

viewsbroken.idr:12:10: 在解析 ViewsBroken.parity 的右侧时:
    Parity (plus (S j) (S j))
与
    Parity (S (S (plus j j)))
的类型不匹配

具体为:
        plus (S j) (S j)
    与
        S (S (plus j j))
    的类型不匹配

类型检查器告诉我们 (j+1)+(j+1)2+j+j 无法规范化为相同的值。这是因为 plus 是在第一个参数上递归定义的,而在第二个值中,有一个后继符号作用在第二个参数上, 因此它无法帮助归约。这些值明显相等 — 不过我们要如何重写程序来修复此问题?

临时定义

临时定义(Provisional Definition) 允许我们推迟证明的细节以帮助解决此问题。 它主要有两个作用:

  • 成型(Prototyping) 时,它可在所有的证明细节结束前测试程序。
  • 阅读 程序时,推迟证明的细节通常会让过程更清晰,避免读者从底层算法中分心。

临时定义的写法和普通定义相同,只是它以 ?= 而非 = 引入右式。我们将 parity 定义为:

parity : (n:Nat) -> Parity n
parity Z = Even {n=Z}
parity (S Z) = Odd {n=Z}
parity (S (S k)) with (parity k)
  parity (S (S (j + j))) | Even ?= Even {n=S j}
  parity (S (S (S (j + j)))) | Odd ?= Odd {n=S j}

当写成这种形式时,Idris 不会报告类型错误,而是在定理中挖一个坑,以此来修正类型错误。 Idris 会告诉我们有两个证明义务,其名字根据模块和函数名生成:

*views> :m
Global holes:
        [views.parity_lemma_2,views.parity_lemma_1]

其中第一个坑的类型如下:

*views> :p views.parity_lemma_1

---------------------------------- (views.parity_lemma_1) --------
{hole0} : (j : Nat) -> (Parity (plus (S j) (S j))) -> Parity (S (S (plus j j)))

-views.parity_lemma_1>

它的两个参数为 j,一个是模式匹配作用域中的变量,另一个是 value, 它是我们在临时定义右侧给出的值。我们的目标是重写类型以便让我们能使用该值。 我们可以用 Prelude 中的以下定理来达到此目的:

plusSuccRightSucc : (left : Nat) -> (right : Nat) ->
  S (left + right) = left + (S right)

还要再用 compute 来展开 plus 的定义:

-views.parity_lemma_1> compute


---------------------------------- (views.parity_lemma_1) --------
{hole0} : (j : Nat) -> (Parity (S (plus j (S j)))) -> Parity (S (S (plus j j)))

在应用 intros 之后,我们有:

-views.parity_lemma_1> intros

  j : Nat
  value : Parity (S (plus j (S j)))
---------------------------------- (views.parity_lemma_1) --------
{hole2} : Parity (S (S (plus j j)))

接着,我们对称地对 jj 应用 plusSuccRightSucc 重写规则,它会给出:

-views.parity_lemma_1> rewrite sym (plusSuccRightSucc j j)

  j : Nat
  value : Parity (S (plus j (S j)))
---------------------------------- (views.parity_lemma_1) --------
{hole3} : Parity (S (plus j (S j)))

sym 是一个在库中定义的函数,它可以反转重写的顺序:

sym : l = r -> r = l
sym Refl = Refl

我们可以用 trivial 策略来完成此证明,它会在前提中找到 value。 第二个引理的证明方式完全相同。

现在我们可以在提示符中测试 with 规则:匹配中间值 一节中的 natToBin 了。数字 42 的二进制为 101010。其二进制数字以逆序表示:

*views> show (natToBin 42)
"[False, True, False, True, False, True]" : String

暂且相信

Idris 在编译程序前需要完成证明(尽管在提示符中求值可以无需详细证明)。然而有时候, 特别在成型时,不去完成证明反而更容易。在尝试证明它们之前就测试程序甚至可能会更好, 如果测试找到了一个错误,你就会知道最好不要花时间去证明某些东西了!

因此,Idris 提供了一个内建的强迫(coercion)函数,它允许我们使用类型错误的值:

believe_me : a -> b

显然,它的使用必须要非常小心。在成型时它非常有用,在断言外部代码(可能在外部的 C 库中)的性质时也是可以用的。使用了它的 views.parity_lemma_1 的「证明」为:

views.parity_lemma_2 = proof {
    intro;
    intro;
    exact believe_me value;
}

exact 策略允许我们为该证明提供一个确切的值。在本例中,我们断言给出的值是正确的。

示例:二进制数

我们在前面通过 Parity 视角实现了 Nat 到二进制数的转换。在这里, 我们会展示如何用同样的视角来实现已验证的二进制转换。我们首先在与其等价的 Nat 上索引二进制数。这是一种通用的模式,即将它的表示(这里为 Binary)与其含义 (这里为 Nat)关联起来:

data Binary : Nat -> Type where
   BEnd : Binary Z
   BO : Binary n -> Binary (n + n)
   BI : Binary n -> Binary (S (n + n))

BOBI 接受一个二进制数作为其参数并立即将它左移一位, 然后再加零或一作为新的最低位。索引 n + nS (n + n) 描述了左移后再相加的结果与该数值的意义相同。它会产生低位在前的表示。

现在,将 Nat 转换为二进制的函数在其类型中描述了结果二进制数为原始 Nat 的正确表示:

natToBin : (n:Nat) -> Binary n

Parity 视角让定义变得相当简单:把数折半其实就是进行一次右移,尽管我们需要在 Odd 的情况下使用临时定义:

natToBin : (n:Nat) -> Binary n
natToBin Z = BEnd
natToBin (S k) with (parity k)
   natToBin (S (j + j))     | Even  = BI (natToBin j)
   natToBin (S (S (j + j))) | Odd  ?= BO (natToBin (S j))

Odd 情况的问题与 parity 定义中的相同,其证明过程也一样:

natToBin_lemma_1 = proof {
    intro;
    intro;
    rewrite sym (plusSuccRightSucc j j);
    trivial;
}

最后,我们来实现一个 main 程序,它读取用户输入的整数并输出为二进制:

main : IO ()
main = do putStr "Enter a number: "
          x <- getLine
          print (natToBin (fromInteger (cast x)))

当然,为了能让它工作,我们需要为 Binary n 实现 Show

Show (Binary n) where
    show (BO x) = show x ++ "0"
    show (BI x) = show x ++ "1"
    show BEnd = ""