Update Go library to r60.
From-SVN: r178910
This commit is contained in:
parent
5548ca3540
commit
adb0401dac
718 changed files with 58911 additions and 30469 deletions
|
@ -806,7 +806,7 @@ proc go-gc-tests { } {
|
|||
$status $name
|
||||
} else {
|
||||
verbose -log $comp_output
|
||||
fali $name
|
||||
fail $name
|
||||
}
|
||||
file delete $ofile1 $ofile2 $output_file
|
||||
set runtests $hold_runtests
|
||||
|
|
|
@ -37,7 +37,7 @@ func main() {
|
|||
}
|
||||
fmt.Fprintln(out, `}`)
|
||||
}
|
||||
|
||||
|
||||
do(recv)
|
||||
do(send)
|
||||
do(recvOrder)
|
||||
|
@ -54,8 +54,8 @@ func run(t *template.Template, a interface{}, out io.Writer) {
|
|||
}
|
||||
}
|
||||
|
||||
type arg struct{
|
||||
def bool
|
||||
type arg struct {
|
||||
def bool
|
||||
nreset int
|
||||
}
|
||||
|
||||
|
@ -135,181 +135,180 @@ func main() {
|
|||
}
|
||||
`
|
||||
|
||||
func parse(s string) *template.Template {
|
||||
t := template.New(nil)
|
||||
t.SetDelims("〈", "〉")
|
||||
if err := t.Parse(s); err != nil {
|
||||
panic(s)
|
||||
func parse(name, s string) *template.Template {
|
||||
t, err := template.New(name).Parse(s)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("%q: %s", name, err))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
var recv = parse(`
|
||||
〈# Send n, receive it one way or another into x, check that they match.〉
|
||||
var recv = parse("recv", `
|
||||
{{/* Send n, receive it one way or another into x, check that they match. */}}
|
||||
c <- n
|
||||
〈.section Maybe〉
|
||||
{{if .Maybe}}
|
||||
x = <-c
|
||||
〈.or〉
|
||||
{{else}}
|
||||
select {
|
||||
〈# Blocking or non-blocking, before the receive.〉
|
||||
〈# The compiler implements two-case select where one is default with custom code,〉
|
||||
〈# so test the default branch both before and after the send.〉
|
||||
〈.section MaybeDefault〉
|
||||
{{/* Blocking or non-blocking, before the receive. */}}
|
||||
{{/* The compiler implements two-case select where one is default with custom code, */}}
|
||||
{{/* so test the default branch both before and after the send. */}}
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
panic("nonblock")
|
||||
〈.end〉
|
||||
〈# Receive from c. Different cases are direct, indirect, :=, interface, and map assignment.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Receive from c. Different cases are direct, indirect, :=, interface, and map assignment. */}}
|
||||
{{if .Maybe}}
|
||||
case x = <-c:
|
||||
〈.or〉〈.section Maybe〉
|
||||
{{else}}{{if .Maybe}}
|
||||
case *f(&x) = <-c:
|
||||
〈.or〉〈.section Maybe〉
|
||||
{{else}}{{if .Maybe}}
|
||||
case y := <-c:
|
||||
x = y
|
||||
〈.or〉〈.section Maybe〉
|
||||
{{else}}{{if .Maybe}}
|
||||
case i = <-c:
|
||||
x = i.(int)
|
||||
〈.or〉
|
||||
{{else}}
|
||||
case m[13] = <-c:
|
||||
x = m[13]
|
||||
〈.end〉〈.end〉〈.end〉〈.end〉
|
||||
〈# Blocking or non-blocking again, after the receive.〉
|
||||
〈.section MaybeDefault〉
|
||||
{{end}}{{end}}{{end}}{{end}}
|
||||
{{/* Blocking or non-blocking again, after the receive. */}}
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
panic("nonblock")
|
||||
〈.end〉
|
||||
〈# Dummy send, receive to keep compiler from optimizing select.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Dummy send, receive to keep compiler from optimizing select. */}}
|
||||
{{if .Maybe}}
|
||||
case dummy <- 1:
|
||||
panic("dummy send")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-dummy:
|
||||
panic("dummy receive")
|
||||
〈.end〉
|
||||
〈# Nil channel send, receive to keep compiler from optimizing select.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Nil channel send, receive to keep compiler from optimizing select. */}}
|
||||
{{if .Maybe}}
|
||||
case nilch <- 1:
|
||||
panic("nilch send")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-nilch:
|
||||
panic("nilch recv")
|
||||
〈.end〉
|
||||
{{end}}
|
||||
}
|
||||
〈.end〉
|
||||
{{end}}
|
||||
if x != n {
|
||||
die(x)
|
||||
}
|
||||
n++
|
||||
`)
|
||||
|
||||
var recvOrder = parse(`
|
||||
〈# Send n, receive it one way or another into x, check that they match.〉
|
||||
〈# Check order of operations along the way by calling functions that check〉
|
||||
〈# that the argument sequence is strictly increasing.〉
|
||||
var recvOrder = parse("recvOrder", `
|
||||
{{/* Send n, receive it one way or another into x, check that they match. */}}
|
||||
{{/* Check order of operations along the way by calling functions that check */}}
|
||||
{{/* that the argument sequence is strictly increasing. */}}
|
||||
order = 0
|
||||
c <- n
|
||||
〈.section Maybe〉
|
||||
〈# Outside of select, left-to-right rule applies.〉
|
||||
〈# (Inside select, assignment waits until case is chosen,〉
|
||||
〈# so right hand side happens before anything on left hand side.〉
|
||||
{{if .Maybe}}
|
||||
{{/* Outside of select, left-to-right rule applies. */}}
|
||||
{{/* (Inside select, assignment waits until case is chosen, */}}
|
||||
{{/* so right hand side happens before anything on left hand side. */}}
|
||||
*fp(&x, 1) = <-fc(c, 2)
|
||||
〈.or〉〈.section Maybe〉
|
||||
{{else}}{{if .Maybe}}
|
||||
m[fn(13, 1)] = <-fc(c, 2)
|
||||
x = m[13]
|
||||
〈.or〉
|
||||
{{else}}
|
||||
select {
|
||||
〈# Blocking or non-blocking, before the receive.〉
|
||||
〈# The compiler implements two-case select where one is default with custom code,〉
|
||||
〈# so test the default branch both before and after the send.〉
|
||||
〈.section MaybeDefault〉
|
||||
{{/* Blocking or non-blocking, before the receive. */}}
|
||||
{{/* The compiler implements two-case select where one is default with custom code, */}}
|
||||
{{/* so test the default branch both before and after the send. */}}
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
panic("nonblock")
|
||||
〈.end〉
|
||||
〈# Receive from c. Different cases are direct, indirect, :=, interface, and map assignment.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Receive from c. Different cases are direct, indirect, :=, interface, and map assignment. */}}
|
||||
{{if .Maybe}}
|
||||
case *fp(&x, 100) = <-fc(c, 1):
|
||||
〈.or〉〈.section Maybe〉
|
||||
{{else}}{{if .Maybe}}
|
||||
case y := <-fc(c, 1):
|
||||
x = y
|
||||
〈.or〉〈.section Maybe〉
|
||||
{{else}}{{if .Maybe}}
|
||||
case i = <-fc(c, 1):
|
||||
x = i.(int)
|
||||
〈.or〉
|
||||
{{else}}
|
||||
case m[fn(13, 100)] = <-fc(c, 1):
|
||||
x = m[13]
|
||||
〈.end〉〈.end〉〈.end〉
|
||||
〈# Blocking or non-blocking again, after the receive.〉
|
||||
〈.section MaybeDefault〉
|
||||
{{end}}{{end}}{{end}}
|
||||
{{/* Blocking or non-blocking again, after the receive. */}}
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
panic("nonblock")
|
||||
〈.end〉
|
||||
〈# Dummy send, receive to keep compiler from optimizing select.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Dummy send, receive to keep compiler from optimizing select. */}}
|
||||
{{if .Maybe}}
|
||||
case fc(dummy, 2) <- fn(1, 3):
|
||||
panic("dummy send")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-fc(dummy, 4):
|
||||
panic("dummy receive")
|
||||
〈.end〉
|
||||
〈# Nil channel send, receive to keep compiler from optimizing select.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Nil channel send, receive to keep compiler from optimizing select. */}}
|
||||
{{if .Maybe}}
|
||||
case fc(nilch, 5) <- fn(1, 6):
|
||||
panic("nilch send")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-fc(nilch, 7):
|
||||
panic("nilch recv")
|
||||
〈.end〉
|
||||
{{end}}
|
||||
}
|
||||
〈.end〉〈.end〉
|
||||
{{end}}{{end}}
|
||||
if x != n {
|
||||
die(x)
|
||||
}
|
||||
n++
|
||||
`)
|
||||
|
||||
var send = parse(`
|
||||
〈# Send n one way or another, receive it into x, check that they match.〉
|
||||
〈.section Maybe〉
|
||||
var send = parse("send", `
|
||||
{{/* Send n one way or another, receive it into x, check that they match. */}}
|
||||
{{if .Maybe}}
|
||||
c <- n
|
||||
〈.or〉
|
||||
{{else}}
|
||||
select {
|
||||
〈# Blocking or non-blocking, before the receive (same reason as in recv).〉
|
||||
〈.section MaybeDefault〉
|
||||
{{/* Blocking or non-blocking, before the receive (same reason as in recv). */}}
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
panic("nonblock")
|
||||
〈.end〉
|
||||
〈# Send c <- n. No real special cases here, because no values come back〉
|
||||
〈# from the send operation.〉
|
||||
{{end}}
|
||||
{{/* Send c <- n. No real special cases here, because no values come back */}}
|
||||
{{/* from the send operation. */}}
|
||||
case c <- n:
|
||||
〈# Blocking or non-blocking.〉
|
||||
〈.section MaybeDefault〉
|
||||
{{/* Blocking or non-blocking. */}}
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
panic("nonblock")
|
||||
〈.end〉
|
||||
〈# Dummy send, receive to keep compiler from optimizing select.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Dummy send, receive to keep compiler from optimizing select. */}}
|
||||
{{if .Maybe}}
|
||||
case dummy <- 1:
|
||||
panic("dummy send")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-dummy:
|
||||
panic("dummy receive")
|
||||
〈.end〉
|
||||
〈# Nil channel send, receive to keep compiler from optimizing select.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Nil channel send, receive to keep compiler from optimizing select. */}}
|
||||
{{if .Maybe}}
|
||||
case nilch <- 1:
|
||||
panic("nilch send")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-nilch:
|
||||
panic("nilch recv")
|
||||
〈.end〉
|
||||
{{end}}
|
||||
}
|
||||
〈.end〉
|
||||
{{end}}
|
||||
x = <-c
|
||||
if x != n {
|
||||
die(x)
|
||||
|
@ -317,48 +316,48 @@ var send = parse(`
|
|||
n++
|
||||
`)
|
||||
|
||||
var sendOrder = parse(`
|
||||
〈# Send n one way or another, receive it into x, check that they match.〉
|
||||
〈# Check order of operations along the way by calling functions that check〉
|
||||
〈# that the argument sequence is strictly increasing.〉
|
||||
var sendOrder = parse("sendOrder", `
|
||||
{{/* Send n one way or another, receive it into x, check that they match. */}}
|
||||
{{/* Check order of operations along the way by calling functions that check */}}
|
||||
{{/* that the argument sequence is strictly increasing. */}}
|
||||
order = 0
|
||||
〈.section Maybe〉
|
||||
{{if .Maybe}}
|
||||
fc(c, 1) <- fn(n, 2)
|
||||
〈.or〉
|
||||
{{else}}
|
||||
select {
|
||||
〈# Blocking or non-blocking, before the receive (same reason as in recv).〉
|
||||
〈.section MaybeDefault〉
|
||||
{{/* Blocking or non-blocking, before the receive (same reason as in recv). */}}
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
panic("nonblock")
|
||||
〈.end〉
|
||||
〈# Send c <- n. No real special cases here, because no values come back〉
|
||||
〈# from the send operation.〉
|
||||
{{end}}
|
||||
{{/* Send c <- n. No real special cases here, because no values come back */}}
|
||||
{{/* from the send operation. */}}
|
||||
case fc(c, 1) <- fn(n, 2):
|
||||
〈# Blocking or non-blocking.〉
|
||||
〈.section MaybeDefault〉
|
||||
{{/* Blocking or non-blocking. */}}
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
panic("nonblock")
|
||||
〈.end〉
|
||||
〈# Dummy send, receive to keep compiler from optimizing select.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Dummy send, receive to keep compiler from optimizing select. */}}
|
||||
{{if .Maybe}}
|
||||
case fc(dummy, 3) <- fn(1, 4):
|
||||
panic("dummy send")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-fc(dummy, 5):
|
||||
panic("dummy receive")
|
||||
〈.end〉
|
||||
〈# Nil channel send, receive to keep compiler from optimizing select.〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{/* Nil channel send, receive to keep compiler from optimizing select. */}}
|
||||
{{if .Maybe}}
|
||||
case fc(nilch, 6) <- fn(1, 7):
|
||||
panic("nilch send")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-fc(nilch, 8):
|
||||
panic("nilch recv")
|
||||
〈.end〉
|
||||
{{end}}
|
||||
}
|
||||
〈.end〉
|
||||
{{end}}
|
||||
x = <-c
|
||||
if x != n {
|
||||
die(x)
|
||||
|
@ -366,49 +365,49 @@ var sendOrder = parse(`
|
|||
n++
|
||||
`)
|
||||
|
||||
var nonblock = parse(`
|
||||
var nonblock = parse("nonblock", `
|
||||
x = n
|
||||
〈# Test various combinations of non-blocking operations.〉
|
||||
〈# Receive assignments must not edit or even attempt to compute the address of the lhs.〉
|
||||
{{/* Test various combinations of non-blocking operations. */}}
|
||||
{{/* Receive assignments must not edit or even attempt to compute the address of the lhs. */}}
|
||||
select {
|
||||
〈.section MaybeDefault〉
|
||||
{{if .MaybeDefault}}
|
||||
default:
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case dummy <- 1:
|
||||
panic("dummy <- 1")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case nilch <- 1:
|
||||
panic("nilch <- 1")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-dummy:
|
||||
panic("<-dummy")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case x = <-dummy:
|
||||
panic("<-dummy x")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case **(**int)(nil) = <-dummy:
|
||||
panic("<-dummy (and didn't crash saving result!)")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case <-nilch:
|
||||
panic("<-nilch")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case x = <-nilch:
|
||||
panic("<-nilch x")
|
||||
〈.end〉
|
||||
〈.section Maybe〉
|
||||
{{end}}
|
||||
{{if .Maybe}}
|
||||
case **(**int)(nil) = <-nilch:
|
||||
panic("<-nilch (and didn't crash saving result!)")
|
||||
〈.end〉
|
||||
〈.section MustDefault〉
|
||||
{{end}}
|
||||
{{if .MustDefault}}
|
||||
default:
|
||||
〈.end〉
|
||||
{{end}}
|
||||
}
|
||||
if x != n {
|
||||
die(x)
|
||||
|
@ -466,7 +465,7 @@ func next() bool {
|
|||
}
|
||||
|
||||
// increment last choice sequence
|
||||
cp = len(choices)-1
|
||||
cp = len(choices) - 1
|
||||
for cp >= 0 && choices[cp].i == choices[cp].n-1 {
|
||||
cp--
|
||||
}
|
||||
|
@ -479,4 +478,3 @@ func next() bool {
|
|||
cp = 0
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ func Listen(x, y string) (T, string) {
|
|||
}
|
||||
|
||||
func (t T) Addr() os.Error {
|
||||
return os.ErrorString("stringer")
|
||||
return os.NewError("stringer")
|
||||
}
|
||||
|
||||
func (t T) Accept() (int, string) {
|
||||
|
@ -49,4 +49,3 @@ func Dial(x, y, z string) (int, string) {
|
|||
global <- 1
|
||||
return 0, ""
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ var chatty = flag.Bool("v", false, "chatty")
|
|||
var oldsys uint64
|
||||
|
||||
func bigger() {
|
||||
runtime.UpdateMemStats()
|
||||
if st := runtime.MemStats; oldsys < st.Sys {
|
||||
oldsys = st.Sys
|
||||
if *chatty {
|
||||
|
@ -31,7 +32,7 @@ func bigger() {
|
|||
}
|
||||
|
||||
func main() {
|
||||
runtime.GC() // clean up garbage from init
|
||||
runtime.GC() // clean up garbage from init
|
||||
runtime.MemProfileRate = 0 // disable profiler
|
||||
runtime.MemStats.Alloc = 0 // ignore stacks
|
||||
flag.Parse()
|
||||
|
@ -45,8 +46,10 @@ func main() {
|
|||
panic("fail")
|
||||
}
|
||||
b := runtime.Alloc(uintptr(j))
|
||||
runtime.UpdateMemStats()
|
||||
during := runtime.MemStats.Alloc
|
||||
runtime.Free(b)
|
||||
runtime.UpdateMemStats()
|
||||
if a := runtime.MemStats.Alloc; a != 0 {
|
||||
println("allocated ", j, ": wrong stats: during=", during, " after=", a, " (want 0)")
|
||||
panic("fail")
|
||||
|
|
|
@ -42,6 +42,7 @@ func AllocAndFree(size, count int) {
|
|||
if *chatty {
|
||||
fmt.Printf("size=%d count=%d ...\n", size, count)
|
||||
}
|
||||
runtime.UpdateMemStats()
|
||||
n1 := stats.Alloc
|
||||
for i := 0; i < count; i++ {
|
||||
b[i] = runtime.Alloc(uintptr(size))
|
||||
|
@ -50,11 +51,13 @@ func AllocAndFree(size, count int) {
|
|||
println("lookup failed: got", base, n, "for", b[i])
|
||||
panic("fail")
|
||||
}
|
||||
if runtime.MemStats.Sys > 1e9 {
|
||||
runtime.UpdateMemStats()
|
||||
if stats.Sys > 1e9 {
|
||||
println("too much memory allocated")
|
||||
panic("fail")
|
||||
}
|
||||
}
|
||||
runtime.UpdateMemStats()
|
||||
n2 := stats.Alloc
|
||||
if *chatty {
|
||||
fmt.Printf("size=%d count=%d stats=%+v\n", size, count, *stats)
|
||||
|
@ -72,6 +75,7 @@ func AllocAndFree(size, count int) {
|
|||
panic("fail")
|
||||
}
|
||||
runtime.Free(b[i])
|
||||
runtime.UpdateMemStats()
|
||||
if stats.Alloc != uint64(alloc-n) {
|
||||
println("free alloc got", stats.Alloc, "expected", alloc-n, "after free of", n)
|
||||
panic("fail")
|
||||
|
@ -81,6 +85,7 @@ func AllocAndFree(size, count int) {
|
|||
panic("fail")
|
||||
}
|
||||
}
|
||||
runtime.UpdateMemStats()
|
||||
n4 := stats.Alloc
|
||||
|
||||
if *chatty {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
aea0ba6e5935
|
||||
504f4e9b079c
|
||||
|
||||
The first line of this file holds the Mercurial revision number of the
|
||||
last merge done from the master library sources.
|
||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -12,12 +12,24 @@
|
|||
/* Define to 1 if you have the <inttypes.h> header file. */
|
||||
#undef HAVE_INTTYPES_H
|
||||
|
||||
/* Define to 1 if you have the <linux/filter.h> header file. */
|
||||
#undef HAVE_LINUX_FILTER_H
|
||||
|
||||
/* Define to 1 if you have the <linux/netlink.h> header file. */
|
||||
#undef HAVE_LINUX_NETLINK_H
|
||||
|
||||
/* Define to 1 if you have the <linux/rtnetlink.h> header file. */
|
||||
#undef HAVE_LINUX_RTNETLINK_H
|
||||
|
||||
/* Define to 1 if you have the <memory.h> header file. */
|
||||
#undef HAVE_MEMORY_H
|
||||
|
||||
/* Define to 1 if you have the `mincore' function. */
|
||||
#undef HAVE_MINCORE
|
||||
|
||||
/* Define to 1 if you have the <net/if.h> header file. */
|
||||
#undef HAVE_NET_IF_H
|
||||
|
||||
/* Define to 1 if the system has the type `off64_t'. */
|
||||
#undef HAVE_OFF64_T
|
||||
|
||||
|
@ -71,6 +83,9 @@
|
|||
/* Define to 1 if you have the <sys/select.h> header file. */
|
||||
#undef HAVE_SYS_SELECT_H
|
||||
|
||||
/* Define to 1 if you have the <sys/socket.h> header file. */
|
||||
#undef HAVE_SYS_SOCKET_H
|
||||
|
||||
/* Define to 1 if you have the <sys/stat.h> header file. */
|
||||
#undef HAVE_SYS_STAT_H
|
||||
|
||||
|
|
33
libgo/configure
vendored
33
libgo/configure
vendored
|
@ -617,7 +617,6 @@ USING_SPLIT_STACK_FALSE
|
|||
USING_SPLIT_STACK_TRUE
|
||||
SPLIT_STACK
|
||||
OSCFLAGS
|
||||
GO_DEBUG_PROC_REGS_OS_ARCH_FILE
|
||||
GO_SYSCALLS_SYSCALL_OS_ARCH_FILE
|
||||
GOARCH
|
||||
LIBGO_IS_X86_64_FALSE
|
||||
|
@ -10914,7 +10913,7 @@ else
|
|||
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
|
||||
lt_status=$lt_dlunknown
|
||||
cat > conftest.$ac_ext <<_LT_EOF
|
||||
#line 10917 "configure"
|
||||
#line 10916 "configure"
|
||||
#include "confdefs.h"
|
||||
|
||||
#if HAVE_DLFCN_H
|
||||
|
@ -11020,7 +11019,7 @@ else
|
|||
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
|
||||
lt_status=$lt_dlunknown
|
||||
cat > conftest.$ac_ext <<_LT_EOF
|
||||
#line 11023 "configure"
|
||||
#line 11022 "configure"
|
||||
#include "confdefs.h"
|
||||
|
||||
#if HAVE_DLFCN_H
|
||||
|
@ -13558,12 +13557,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then
|
|||
fi
|
||||
|
||||
|
||||
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=
|
||||
if test -f ${srcdir}/go/debug/proc/regs_${GOOS}_${GOARCH}.go; then
|
||||
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=go/debug/proc/regs_${GOOS}_${GOARCH}.go
|
||||
fi
|
||||
|
||||
|
||||
case "$target" in
|
||||
mips-sgi-irix6.5*)
|
||||
# IRIX 6 needs _XOPEN_SOURCE=500 for the XPG5 version of struct
|
||||
|
@ -14252,7 +14245,7 @@ no)
|
|||
;;
|
||||
esac
|
||||
|
||||
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h
|
||||
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h
|
||||
do :
|
||||
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
|
||||
ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"
|
||||
|
@ -14266,6 +14259,26 @@ fi
|
|||
|
||||
done
|
||||
|
||||
|
||||
for ac_header in linux/filter.h linux/netlink.h linux/rtnetlink.h
|
||||
do :
|
||||
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
|
||||
ac_fn_c_check_header_compile "$LINENO" "$ac_header" "$as_ac_Header" "#ifdef HAVE_SYS_SOCKET_H
|
||||
#include <sys/socket.h>
|
||||
#endif
|
||||
|
||||
"
|
||||
eval as_val=\$$as_ac_Header
|
||||
if test "x$as_val" = x""yes; then :
|
||||
cat >>confdefs.h <<_ACEOF
|
||||
#define `$as_echo "HAVE_$ac_header" | $as_tr_cpp` 1
|
||||
_ACEOF
|
||||
|
||||
fi
|
||||
|
||||
done
|
||||
|
||||
|
||||
if test "$ac_cv_header_sys_mman_h" = yes; then
|
||||
HAVE_SYS_MMAN_H_TRUE=
|
||||
HAVE_SYS_MMAN_H_FALSE='#'
|
||||
|
|
|
@ -255,12 +255,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then
|
|||
fi
|
||||
AC_SUBST(GO_SYSCALLS_SYSCALL_OS_ARCH_FILE)
|
||||
|
||||
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=
|
||||
if test -f ${srcdir}/go/debug/proc/regs_${GOOS}_${GOARCH}.go; then
|
||||
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=go/debug/proc/regs_${GOOS}_${GOARCH}.go
|
||||
fi
|
||||
AC_SUBST(GO_DEBUG_PROC_REGS_OS_ARCH_FILE)
|
||||
|
||||
dnl Some targets need special flags to build sysinfo.go.
|
||||
case "$target" in
|
||||
mips-sgi-irix6.5*)
|
||||
|
@ -431,7 +425,14 @@ no)
|
|||
;;
|
||||
esac
|
||||
|
||||
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h)
|
||||
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h)
|
||||
|
||||
AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [],
|
||||
[#ifdef HAVE_SYS_SOCKET_H
|
||||
#include <sys/socket.h>
|
||||
#endif
|
||||
])
|
||||
|
||||
AM_CONDITIONAL(HAVE_SYS_MMAN_H, test "$ac_cv_header_sys_mman_h" = yes)
|
||||
|
||||
AC_CHECK_FUNCS(srandom random strerror_r strsignal wait4 mincore setenv)
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
HeaderError os.Error = os.ErrorString("invalid tar header")
|
||||
HeaderError = os.NewError("invalid tar header")
|
||||
)
|
||||
|
||||
// A Reader provides sequential access to the contents of a tar archive.
|
||||
|
|
|
@ -178,7 +178,6 @@ func TestPartialRead(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestIncrementalRead(t *testing.T) {
|
||||
test := gnuTarTest
|
||||
f, err := os.Open(test.file)
|
||||
|
|
|
@ -2,18 +2,10 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package zip provides support for reading ZIP archives.
|
||||
|
||||
See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
|
||||
|
||||
This package does not support ZIP64 or disk spanning.
|
||||
*/
|
||||
package zip
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"hash"
|
||||
"hash/crc32"
|
||||
|
@ -24,9 +16,9 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
FormatError = os.NewError("not a valid zip file")
|
||||
UnsupportedMethod = os.NewError("unsupported compression algorithm")
|
||||
ChecksumError = os.NewError("checksum error")
|
||||
FormatError = os.NewError("zip: not a valid zip file")
|
||||
UnsupportedMethod = os.NewError("zip: unsupported compression algorithm")
|
||||
ChecksumError = os.NewError("zip: checksum error")
|
||||
)
|
||||
|
||||
type Reader struct {
|
||||
|
@ -44,15 +36,14 @@ type File struct {
|
|||
FileHeader
|
||||
zipr io.ReaderAt
|
||||
zipsize int64
|
||||
headerOffset uint32
|
||||
bodyOffset int64
|
||||
headerOffset int64
|
||||
}
|
||||
|
||||
func (f *File) hasDataDescriptor() bool {
|
||||
return f.Flags&0x8 != 0
|
||||
}
|
||||
|
||||
// OpenReader will open the Zip file specified by name and return a ReaderCloser.
|
||||
// OpenReader will open the Zip file specified by name and return a ReadCloser.
|
||||
func OpenReader(name string) (*ReadCloser, os.Error) {
|
||||
f, err := os.Open(name)
|
||||
if err != nil {
|
||||
|
@ -87,18 +78,33 @@ func (z *Reader) init(r io.ReaderAt, size int64) os.Error {
|
|||
return err
|
||||
}
|
||||
z.r = r
|
||||
z.File = make([]*File, end.directoryRecords)
|
||||
z.File = make([]*File, 0, end.directoryRecords)
|
||||
z.Comment = end.comment
|
||||
rs := io.NewSectionReader(r, 0, size)
|
||||
if _, err = rs.Seek(int64(end.directoryOffset), os.SEEK_SET); err != nil {
|
||||
return err
|
||||
}
|
||||
buf := bufio.NewReader(rs)
|
||||
for i := range z.File {
|
||||
z.File[i] = &File{zipr: r, zipsize: size}
|
||||
if err := readDirectoryHeader(z.File[i], buf); err != nil {
|
||||
|
||||
// The count of files inside a zip is truncated to fit in a uint16.
|
||||
// Gloss over this by reading headers until we encounter
|
||||
// a bad one, and then only report a FormatError or UnexpectedEOF if
|
||||
// the file count modulo 65536 is incorrect.
|
||||
for {
|
||||
f := &File{zipr: r, zipsize: size}
|
||||
err = readDirectoryHeader(f, buf)
|
||||
if err == FormatError || err == io.ErrUnexpectedEOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
z.File = append(z.File, f)
|
||||
}
|
||||
if uint16(len(z.File)) != end.directoryRecords {
|
||||
// Return the readDirectoryHeader error if we read
|
||||
// the wrong number of directory entries.
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -109,31 +115,22 @@ func (rc *ReadCloser) Close() os.Error {
|
|||
}
|
||||
|
||||
// Open returns a ReadCloser that provides access to the File's contents.
|
||||
// It is safe to Open and Read from files concurrently.
|
||||
func (f *File) Open() (rc io.ReadCloser, err os.Error) {
|
||||
off := int64(f.headerOffset)
|
||||
if f.bodyOffset == 0 {
|
||||
r := io.NewSectionReader(f.zipr, off, f.zipsize-off)
|
||||
if err = readFileHeader(f, r); err != nil {
|
||||
return
|
||||
}
|
||||
if f.bodyOffset, err = r.Seek(0, os.SEEK_CUR); err != nil {
|
||||
return
|
||||
}
|
||||
bodyOffset, err := f.findBodyOffset()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
size := int64(f.CompressedSize)
|
||||
if f.hasDataDescriptor() {
|
||||
if size == 0 {
|
||||
// permit SectionReader to see the rest of the file
|
||||
size = f.zipsize - (off + f.bodyOffset)
|
||||
} else {
|
||||
size += dataDescriptorLen
|
||||
}
|
||||
if size == 0 && f.hasDataDescriptor() {
|
||||
// permit SectionReader to see the rest of the file
|
||||
size = f.zipsize - (f.headerOffset + bodyOffset)
|
||||
}
|
||||
r := io.NewSectionReader(f.zipr, off+f.bodyOffset, size)
|
||||
r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size)
|
||||
switch f.Method {
|
||||
case 0: // store (no compression)
|
||||
case Store: // (no compression)
|
||||
rc = ioutil.NopCloser(r)
|
||||
case 8: // DEFLATE
|
||||
case Deflate:
|
||||
rc = flate.NewReader(r)
|
||||
default:
|
||||
err = UnsupportedMethod
|
||||
|
@ -170,90 +167,102 @@ func (r *checksumReader) Read(b []byte) (n int, err os.Error) {
|
|||
|
||||
func (r *checksumReader) Close() os.Error { return r.rc.Close() }
|
||||
|
||||
func readFileHeader(f *File, r io.Reader) (err os.Error) {
|
||||
defer func() {
|
||||
if rerr, ok := recover().(os.Error); ok {
|
||||
err = rerr
|
||||
}
|
||||
}()
|
||||
var (
|
||||
signature uint32
|
||||
filenameLength uint16
|
||||
extraLength uint16
|
||||
)
|
||||
read(r, &signature)
|
||||
if signature != fileHeaderSignature {
|
||||
func readFileHeader(f *File, r io.Reader) os.Error {
|
||||
var b [fileHeaderLen]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
c := binary.LittleEndian
|
||||
if sig := c.Uint32(b[:4]); sig != fileHeaderSignature {
|
||||
return FormatError
|
||||
}
|
||||
read(r, &f.ReaderVersion)
|
||||
read(r, &f.Flags)
|
||||
read(r, &f.Method)
|
||||
read(r, &f.ModifiedTime)
|
||||
read(r, &f.ModifiedDate)
|
||||
read(r, &f.CRC32)
|
||||
read(r, &f.CompressedSize)
|
||||
read(r, &f.UncompressedSize)
|
||||
read(r, &filenameLength)
|
||||
read(r, &extraLength)
|
||||
f.Name = string(readByteSlice(r, filenameLength))
|
||||
f.Extra = readByteSlice(r, extraLength)
|
||||
return
|
||||
f.ReaderVersion = c.Uint16(b[4:6])
|
||||
f.Flags = c.Uint16(b[6:8])
|
||||
f.Method = c.Uint16(b[8:10])
|
||||
f.ModifiedTime = c.Uint16(b[10:12])
|
||||
f.ModifiedDate = c.Uint16(b[12:14])
|
||||
f.CRC32 = c.Uint32(b[14:18])
|
||||
f.CompressedSize = c.Uint32(b[18:22])
|
||||
f.UncompressedSize = c.Uint32(b[22:26])
|
||||
filenameLen := int(c.Uint16(b[26:28]))
|
||||
extraLen := int(c.Uint16(b[28:30]))
|
||||
d := make([]byte, filenameLen+extraLen)
|
||||
if _, err := io.ReadFull(r, d); err != nil {
|
||||
return err
|
||||
}
|
||||
f.Name = string(d[:filenameLen])
|
||||
f.Extra = d[filenameLen:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func readDirectoryHeader(f *File, r io.Reader) (err os.Error) {
|
||||
defer func() {
|
||||
if rerr, ok := recover().(os.Error); ok {
|
||||
err = rerr
|
||||
}
|
||||
}()
|
||||
var (
|
||||
signature uint32
|
||||
filenameLength uint16
|
||||
extraLength uint16
|
||||
commentLength uint16
|
||||
startDiskNumber uint16 // unused
|
||||
internalAttributes uint16 // unused
|
||||
externalAttributes uint32 // unused
|
||||
)
|
||||
read(r, &signature)
|
||||
if signature != directoryHeaderSignature {
|
||||
// findBodyOffset does the minimum work to verify the file has a header
|
||||
// and returns the file body offset.
|
||||
func (f *File) findBodyOffset() (int64, os.Error) {
|
||||
r := io.NewSectionReader(f.zipr, f.headerOffset, f.zipsize-f.headerOffset)
|
||||
var b [fileHeaderLen]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c := binary.LittleEndian
|
||||
if sig := c.Uint32(b[:4]); sig != fileHeaderSignature {
|
||||
return 0, FormatError
|
||||
}
|
||||
filenameLen := int(c.Uint16(b[26:28]))
|
||||
extraLen := int(c.Uint16(b[28:30]))
|
||||
return int64(fileHeaderLen + filenameLen + extraLen), nil
|
||||
}
|
||||
|
||||
// readDirectoryHeader attempts to read a directory header from r.
|
||||
// It returns io.ErrUnexpectedEOF if it cannot read a complete header,
|
||||
// and FormatError if it doesn't find a valid header signature.
|
||||
func readDirectoryHeader(f *File, r io.Reader) os.Error {
|
||||
var b [directoryHeaderLen]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
c := binary.LittleEndian
|
||||
if sig := c.Uint32(b[:4]); sig != directoryHeaderSignature {
|
||||
return FormatError
|
||||
}
|
||||
read(r, &f.CreatorVersion)
|
||||
read(r, &f.ReaderVersion)
|
||||
read(r, &f.Flags)
|
||||
read(r, &f.Method)
|
||||
read(r, &f.ModifiedTime)
|
||||
read(r, &f.ModifiedDate)
|
||||
read(r, &f.CRC32)
|
||||
read(r, &f.CompressedSize)
|
||||
read(r, &f.UncompressedSize)
|
||||
read(r, &filenameLength)
|
||||
read(r, &extraLength)
|
||||
read(r, &commentLength)
|
||||
read(r, &startDiskNumber)
|
||||
read(r, &internalAttributes)
|
||||
read(r, &externalAttributes)
|
||||
read(r, &f.headerOffset)
|
||||
f.Name = string(readByteSlice(r, filenameLength))
|
||||
f.Extra = readByteSlice(r, extraLength)
|
||||
f.Comment = string(readByteSlice(r, commentLength))
|
||||
return
|
||||
f.CreatorVersion = c.Uint16(b[4:6])
|
||||
f.ReaderVersion = c.Uint16(b[6:8])
|
||||
f.Flags = c.Uint16(b[8:10])
|
||||
f.Method = c.Uint16(b[10:12])
|
||||
f.ModifiedTime = c.Uint16(b[12:14])
|
||||
f.ModifiedDate = c.Uint16(b[14:16])
|
||||
f.CRC32 = c.Uint32(b[16:20])
|
||||
f.CompressedSize = c.Uint32(b[20:24])
|
||||
f.UncompressedSize = c.Uint32(b[24:28])
|
||||
filenameLen := int(c.Uint16(b[28:30]))
|
||||
extraLen := int(c.Uint16(b[30:32]))
|
||||
commentLen := int(c.Uint16(b[32:34]))
|
||||
// startDiskNumber := c.Uint16(b[34:36]) // Unused
|
||||
// internalAttributes := c.Uint16(b[36:38]) // Unused
|
||||
// externalAttributes := c.Uint32(b[38:42]) // Unused
|
||||
f.headerOffset = int64(c.Uint32(b[42:46]))
|
||||
d := make([]byte, filenameLen+extraLen+commentLen)
|
||||
if _, err := io.ReadFull(r, d); err != nil {
|
||||
return err
|
||||
}
|
||||
f.Name = string(d[:filenameLen])
|
||||
f.Extra = d[filenameLen : filenameLen+extraLen]
|
||||
f.Comment = string(d[filenameLen+extraLen:])
|
||||
return nil
|
||||
}
|
||||
|
||||
func readDataDescriptor(r io.Reader, f *File) (err os.Error) {
|
||||
defer func() {
|
||||
if rerr, ok := recover().(os.Error); ok {
|
||||
err = rerr
|
||||
}
|
||||
}()
|
||||
read(r, &f.CRC32)
|
||||
read(r, &f.CompressedSize)
|
||||
read(r, &f.UncompressedSize)
|
||||
return
|
||||
func readDataDescriptor(r io.Reader, f *File) os.Error {
|
||||
var b [dataDescriptorLen]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
c := binary.LittleEndian
|
||||
f.CRC32 = c.Uint32(b[:4])
|
||||
f.CompressedSize = c.Uint32(b[4:8])
|
||||
f.UncompressedSize = c.Uint32(b[8:12])
|
||||
return nil
|
||||
}
|
||||
|
||||
func readDirectoryEnd(r io.ReaderAt, size int64) (d *directoryEnd, err os.Error) {
|
||||
func readDirectoryEnd(r io.ReaderAt, size int64) (dir *directoryEnd, err os.Error) {
|
||||
// look for directoryEndSignature in the last 1k, then in the last 65k
|
||||
var b []byte
|
||||
for i, bLen := range []int64{1024, 65 * 1024} {
|
||||
|
@ -274,53 +283,29 @@ func readDirectoryEnd(r io.ReaderAt, size int64) (d *directoryEnd, err os.Error)
|
|||
}
|
||||
|
||||
// read header into struct
|
||||
defer func() {
|
||||
if rerr, ok := recover().(os.Error); ok {
|
||||
err = rerr
|
||||
d = nil
|
||||
}
|
||||
}()
|
||||
br := bytes.NewBuffer(b[4:]) // skip over signature
|
||||
d = new(directoryEnd)
|
||||
read(br, &d.diskNbr)
|
||||
read(br, &d.dirDiskNbr)
|
||||
read(br, &d.dirRecordsThisDisk)
|
||||
read(br, &d.directoryRecords)
|
||||
read(br, &d.directorySize)
|
||||
read(br, &d.directoryOffset)
|
||||
read(br, &d.commentLen)
|
||||
d.comment = string(readByteSlice(br, d.commentLen))
|
||||
c := binary.LittleEndian
|
||||
d := new(directoryEnd)
|
||||
d.diskNbr = c.Uint16(b[4:6])
|
||||
d.dirDiskNbr = c.Uint16(b[6:8])
|
||||
d.dirRecordsThisDisk = c.Uint16(b[8:10])
|
||||
d.directoryRecords = c.Uint16(b[10:12])
|
||||
d.directorySize = c.Uint32(b[12:16])
|
||||
d.directoryOffset = c.Uint32(b[16:20])
|
||||
d.commentLen = c.Uint16(b[20:22])
|
||||
d.comment = string(b[22 : 22+int(d.commentLen)])
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func findSignatureInBlock(b []byte) int {
|
||||
const minSize = 4 + 2 + 2 + 2 + 2 + 4 + 4 + 2 // fixed part of header
|
||||
for i := len(b) - minSize; i >= 0; i-- {
|
||||
for i := len(b) - directoryEndLen; i >= 0; i-- {
|
||||
// defined from directoryEndSignature in struct.go
|
||||
if b[i] == 'P' && b[i+1] == 'K' && b[i+2] == 0x05 && b[i+3] == 0x06 {
|
||||
// n is length of comment
|
||||
n := int(b[i+minSize-2]) | int(b[i+minSize-1])<<8
|
||||
if n+minSize+i == len(b) {
|
||||
n := int(b[i+directoryEndLen-2]) | int(b[i+directoryEndLen-1])<<8
|
||||
if n+directoryEndLen+i == len(b) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func read(r io.Reader, data interface{}) {
|
||||
if err := binary.Read(r, binary.LittleEndian, data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func readByteSlice(r io.Reader, l uint16) []byte {
|
||||
b := make([]byte, l)
|
||||
if l == 0 {
|
||||
return b
|
||||
}
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ZipTest struct {
|
||||
|
@ -24,8 +25,19 @@ type ZipTestFile struct {
|
|||
Name string
|
||||
Content []byte // if blank, will attempt to compare against File
|
||||
File string // name of file to compare to (relative to testdata/)
|
||||
Mtime string // modified time in format "mm-dd-yy hh:mm:ss"
|
||||
}
|
||||
|
||||
// Caution: The Mtime values found for the test files should correspond to
|
||||
// the values listed with unzip -l <zipfile>. However, the values
|
||||
// listed by unzip appear to be off by some hours. When creating
|
||||
// fresh test files and testing them, this issue is not present.
|
||||
// The test files were created in Sydney, so there might be a time
|
||||
// zone issue. The time zone information does have to be encoded
|
||||
// somewhere, because otherwise unzip -l could not provide a different
|
||||
// time from what the archive/zip package provides, but there appears
|
||||
// to be no documentation about this.
|
||||
|
||||
var tests = []ZipTest{
|
||||
{
|
||||
Name: "test.zip",
|
||||
|
@ -34,10 +46,12 @@ var tests = []ZipTest{
|
|||
{
|
||||
Name: "test.txt",
|
||||
Content: []byte("This is a test text file.\n"),
|
||||
Mtime: "09-05-10 12:12:02",
|
||||
},
|
||||
{
|
||||
Name: "gophercolor16x16.png",
|
||||
File: "gophercolor16x16.png",
|
||||
Name: "gophercolor16x16.png",
|
||||
File: "gophercolor16x16.png",
|
||||
Mtime: "09-05-10 15:52:58",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -45,8 +59,9 @@ var tests = []ZipTest{
|
|||
Name: "r.zip",
|
||||
File: []ZipTestFile{
|
||||
{
|
||||
Name: "r/r.zip",
|
||||
File: "r.zip",
|
||||
Name: "r/r.zip",
|
||||
File: "r.zip",
|
||||
Mtime: "03-04-10 00:24:16",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -58,6 +73,7 @@ var tests = []ZipTest{
|
|||
{
|
||||
Name: "filename",
|
||||
Content: []byte("This is a test textfile.\n"),
|
||||
Mtime: "02-02-11 13:06:20",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -136,18 +152,36 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
|
|||
if f.Name != ft.Name {
|
||||
t.Errorf("name=%q, want %q", f.Name, ft.Name)
|
||||
}
|
||||
|
||||
mtime, err := time.Parse("01-02-06 15:04:05", ft.Mtime)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if got, want := f.Mtime_ns()/1e9, mtime.Seconds(); got != want {
|
||||
t.Errorf("%s: mtime=%s (%d); want %s (%d)", f.Name, time.SecondsToUTC(got), got, mtime, want)
|
||||
}
|
||||
|
||||
size0 := f.UncompressedSize
|
||||
|
||||
var b bytes.Buffer
|
||||
r, err := f.Open()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if size1 := f.UncompressedSize; size0 != size1 {
|
||||
t.Errorf("file %q changed f.UncompressedSize from %d to %d", f.Name, size0, size1)
|
||||
}
|
||||
|
||||
_, err = io.Copy(&b, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
r.Close()
|
||||
|
||||
var c []byte
|
||||
if len(ft.Content) != 0 {
|
||||
c = ft.Content
|
||||
|
@ -155,10 +189,12 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
|
|||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if b.Len() != len(c) {
|
||||
t.Errorf("%s: len=%d, want %d", f.Name, b.Len(), len(c))
|
||||
return
|
||||
}
|
||||
|
||||
for i, b := range b.Bytes() {
|
||||
if b != c[i] {
|
||||
t.Errorf("%s: content[%d]=%q want %q", f.Name, i, b, c[i])
|
||||
|
|
|
@ -1,9 +1,32 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package zip provides support for reading and writing ZIP archives.
|
||||
|
||||
See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
|
||||
|
||||
This package does not support ZIP64 or disk spanning.
|
||||
*/
|
||||
package zip
|
||||
|
||||
import "os"
|
||||
import "time"
|
||||
|
||||
// Compression methods.
|
||||
const (
|
||||
Store uint16 = 0
|
||||
Deflate uint16 = 8
|
||||
)
|
||||
|
||||
const (
|
||||
fileHeaderSignature = 0x04034b50
|
||||
directoryHeaderSignature = 0x02014b50
|
||||
directoryEndSignature = 0x06054b50
|
||||
fileHeaderLen = 30 // + filename + extra
|
||||
directoryHeaderLen = 46 // + filename + extra + comment
|
||||
directoryEndLen = 22 // + comment
|
||||
dataDescriptorLen = 12
|
||||
)
|
||||
|
||||
|
@ -13,8 +36,8 @@ type FileHeader struct {
|
|||
ReaderVersion uint16
|
||||
Flags uint16
|
||||
Method uint16
|
||||
ModifiedTime uint16
|
||||
ModifiedDate uint16
|
||||
ModifiedTime uint16 // MS-DOS time
|
||||
ModifiedDate uint16 // MS-DOS date
|
||||
CRC32 uint32
|
||||
CompressedSize uint32
|
||||
UncompressedSize uint32
|
||||
|
@ -32,3 +55,37 @@ type directoryEnd struct {
|
|||
commentLen uint16
|
||||
comment string
|
||||
}
|
||||
|
||||
func recoverError(err *os.Error) {
|
||||
if e := recover(); e != nil {
|
||||
if osErr, ok := e.(os.Error); ok {
|
||||
*err = osErr
|
||||
return
|
||||
}
|
||||
panic(e)
|
||||
}
|
||||
}
|
||||
|
||||
// msDosTimeToTime converts an MS-DOS date and time into a time.Time.
|
||||
// The resolution is 2s.
|
||||
// See: http://msdn.microsoft.com/en-us/library/ms724247(v=VS.85).aspx
|
||||
func msDosTimeToTime(dosDate, dosTime uint16) time.Time {
|
||||
return time.Time{
|
||||
// date bits 0-4: day of month; 5-8: month; 9-15: years since 1980
|
||||
Year: int64(dosDate>>9 + 1980),
|
||||
Month: int(dosDate >> 5 & 0xf),
|
||||
Day: int(dosDate & 0x1f),
|
||||
|
||||
// time bits 0-4: second/2; 5-10: minute; 11-15: hour
|
||||
Hour: int(dosTime >> 11),
|
||||
Minute: int(dosTime >> 5 & 0x3f),
|
||||
Second: int(dosTime & 0x1f * 2),
|
||||
}
|
||||
}
|
||||
|
||||
// Mtime_ns returns the modified time in ns since epoch.
|
||||
// The resolution is 2s.
|
||||
func (h *FileHeader) Mtime_ns() int64 {
|
||||
t := msDosTimeToTime(h.ModifiedDate, h.ModifiedTime)
|
||||
return t.Seconds() * 1e9
|
||||
}
|
||||
|
|
244
libgo/go/archive/zip/writer.go
Normal file
244
libgo/go/archive/zip/writer.go
Normal file
|
@ -0,0 +1,244 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zip
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/flate"
|
||||
"encoding/binary"
|
||||
"hash"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// TODO(adg): support zip file comments
|
||||
// TODO(adg): support specifying deflate level
|
||||
|
||||
// Writer implements a zip file writer.
|
||||
type Writer struct {
|
||||
*countWriter
|
||||
dir []*header
|
||||
last *fileWriter
|
||||
closed bool
|
||||
}
|
||||
|
||||
type header struct {
|
||||
*FileHeader
|
||||
offset uint32
|
||||
}
|
||||
|
||||
// NewWriter returns a new Writer writing a zip file to w.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
return &Writer{countWriter: &countWriter{w: bufio.NewWriter(w)}}
|
||||
}
|
||||
|
||||
// Close finishes writing the zip file by writing the central directory.
|
||||
// It does not (and can not) close the underlying writer.
|
||||
func (w *Writer) Close() (err os.Error) {
|
||||
if w.last != nil && !w.last.closed {
|
||||
if err = w.last.close(); err != nil {
|
||||
return
|
||||
}
|
||||
w.last = nil
|
||||
}
|
||||
if w.closed {
|
||||
return os.NewError("zip: writer closed twice")
|
||||
}
|
||||
w.closed = true
|
||||
|
||||
defer recoverError(&err)
|
||||
|
||||
// write central directory
|
||||
start := w.count
|
||||
for _, h := range w.dir {
|
||||
write(w, uint32(directoryHeaderSignature))
|
||||
write(w, h.CreatorVersion)
|
||||
write(w, h.ReaderVersion)
|
||||
write(w, h.Flags)
|
||||
write(w, h.Method)
|
||||
write(w, h.ModifiedTime)
|
||||
write(w, h.ModifiedDate)
|
||||
write(w, h.CRC32)
|
||||
write(w, h.CompressedSize)
|
||||
write(w, h.UncompressedSize)
|
||||
write(w, uint16(len(h.Name)))
|
||||
write(w, uint16(len(h.Extra)))
|
||||
write(w, uint16(len(h.Comment)))
|
||||
write(w, uint16(0)) // disk number start
|
||||
write(w, uint16(0)) // internal file attributes
|
||||
write(w, uint32(0)) // external file attributes
|
||||
write(w, h.offset)
|
||||
writeBytes(w, []byte(h.Name))
|
||||
writeBytes(w, h.Extra)
|
||||
writeBytes(w, []byte(h.Comment))
|
||||
}
|
||||
end := w.count
|
||||
|
||||
// write end record
|
||||
write(w, uint32(directoryEndSignature))
|
||||
write(w, uint16(0)) // disk number
|
||||
write(w, uint16(0)) // disk number where directory starts
|
||||
write(w, uint16(len(w.dir))) // number of entries this disk
|
||||
write(w, uint16(len(w.dir))) // number of entries total
|
||||
write(w, uint32(end-start)) // size of directory
|
||||
write(w, uint32(start)) // start of directory
|
||||
write(w, uint16(0)) // size of comment
|
||||
|
||||
return w.w.(*bufio.Writer).Flush()
|
||||
}
|
||||
|
||||
// Create adds a file to the zip file using the provided name.
|
||||
// It returns a Writer to which the file contents should be written.
|
||||
// The file's contents must be written to the io.Writer before the next
|
||||
// call to Create, CreateHeader, or Close.
|
||||
func (w *Writer) Create(name string) (io.Writer, os.Error) {
|
||||
header := &FileHeader{
|
||||
Name: name,
|
||||
Method: Deflate,
|
||||
}
|
||||
return w.CreateHeader(header)
|
||||
}
|
||||
|
||||
// CreateHeader adds a file to the zip file using the provided FileHeader
|
||||
// for the file metadata.
|
||||
// It returns a Writer to which the file contents should be written.
|
||||
// The file's contents must be written to the io.Writer before the next
|
||||
// call to Create, CreateHeader, or Close.
|
||||
func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, os.Error) {
|
||||
if w.last != nil && !w.last.closed {
|
||||
if err := w.last.close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
fh.Flags |= 0x8 // we will write a data descriptor
|
||||
fh.CreatorVersion = 0x14
|
||||
fh.ReaderVersion = 0x14
|
||||
|
||||
fw := &fileWriter{
|
||||
zipw: w,
|
||||
compCount: &countWriter{w: w},
|
||||
crc32: crc32.NewIEEE(),
|
||||
}
|
||||
switch fh.Method {
|
||||
case Store:
|
||||
fw.comp = nopCloser{fw.compCount}
|
||||
case Deflate:
|
||||
fw.comp = flate.NewWriter(fw.compCount, 5)
|
||||
default:
|
||||
return nil, UnsupportedMethod
|
||||
}
|
||||
fw.rawCount = &countWriter{w: fw.comp}
|
||||
|
||||
h := &header{
|
||||
FileHeader: fh,
|
||||
offset: uint32(w.count),
|
||||
}
|
||||
w.dir = append(w.dir, h)
|
||||
fw.header = h
|
||||
|
||||
if err := writeHeader(w, fh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w.last = fw
|
||||
return fw, nil
|
||||
}
|
||||
|
||||
func writeHeader(w io.Writer, h *FileHeader) (err os.Error) {
|
||||
defer recoverError(&err)
|
||||
write(w, uint32(fileHeaderSignature))
|
||||
write(w, h.ReaderVersion)
|
||||
write(w, h.Flags)
|
||||
write(w, h.Method)
|
||||
write(w, h.ModifiedTime)
|
||||
write(w, h.ModifiedDate)
|
||||
write(w, h.CRC32)
|
||||
write(w, h.CompressedSize)
|
||||
write(w, h.UncompressedSize)
|
||||
write(w, uint16(len(h.Name)))
|
||||
write(w, uint16(len(h.Extra)))
|
||||
writeBytes(w, []byte(h.Name))
|
||||
writeBytes(w, h.Extra)
|
||||
return nil
|
||||
}
|
||||
|
||||
type fileWriter struct {
|
||||
*header
|
||||
zipw io.Writer
|
||||
rawCount *countWriter
|
||||
comp io.WriteCloser
|
||||
compCount *countWriter
|
||||
crc32 hash.Hash32
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (w *fileWriter) Write(p []byte) (int, os.Error) {
|
||||
if w.closed {
|
||||
return 0, os.NewError("zip: write to closed file")
|
||||
}
|
||||
w.crc32.Write(p)
|
||||
return w.rawCount.Write(p)
|
||||
}
|
||||
|
||||
func (w *fileWriter) close() (err os.Error) {
|
||||
if w.closed {
|
||||
return os.NewError("zip: file closed twice")
|
||||
}
|
||||
w.closed = true
|
||||
if err = w.comp.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// update FileHeader
|
||||
fh := w.header.FileHeader
|
||||
fh.CRC32 = w.crc32.Sum32()
|
||||
fh.CompressedSize = uint32(w.compCount.count)
|
||||
fh.UncompressedSize = uint32(w.rawCount.count)
|
||||
|
||||
// write data descriptor
|
||||
defer recoverError(&err)
|
||||
write(w.zipw, fh.CRC32)
|
||||
write(w.zipw, fh.CompressedSize)
|
||||
write(w.zipw, fh.UncompressedSize)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type countWriter struct {
|
||||
w io.Writer
|
||||
count int64
|
||||
}
|
||||
|
||||
func (w *countWriter) Write(p []byte) (int, os.Error) {
|
||||
n, err := w.w.Write(p)
|
||||
w.count += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
type nopCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (w nopCloser) Close() os.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func write(w io.Writer, data interface{}) {
|
||||
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeBytes(w io.Writer, b []byte) {
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if n != len(b) {
|
||||
panic(io.ErrShortWrite)
|
||||
}
|
||||
}
|
73
libgo/go/archive/zip/writer_test.go
Normal file
73
libgo/go/archive/zip/writer_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zip
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TODO(adg): a more sophisticated test suite
|
||||
|
||||
const testString = "Rabbits, guinea pigs, gophers, marsupial rats, and quolls."
|
||||
|
||||
func TestWriter(t *testing.T) {
|
||||
largeData := make([]byte, 1<<17)
|
||||
for i := range largeData {
|
||||
largeData[i] = byte(rand.Int())
|
||||
}
|
||||
|
||||
// write a zip file
|
||||
buf := new(bytes.Buffer)
|
||||
w := NewWriter(buf)
|
||||
testCreate(t, w, "foo", []byte(testString), Store)
|
||||
testCreate(t, w, "bar", largeData, Deflate)
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// read it back
|
||||
r, err := NewReader(sliceReaderAt(buf.Bytes()), int64(buf.Len()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testReadFile(t, r.File[0], []byte(testString))
|
||||
testReadFile(t, r.File[1], largeData)
|
||||
}
|
||||
|
||||
func testCreate(t *testing.T, w *Writer, name string, data []byte, method uint16) {
|
||||
header := &FileHeader{
|
||||
Name: name,
|
||||
Method: method,
|
||||
}
|
||||
f, err := w.CreateHeader(header)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = f.Write(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testReadFile(t *testing.T, f *File, data []byte) {
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
t.Fatal("opening:", err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(rc)
|
||||
if err != nil {
|
||||
t.Fatal("reading:", err)
|
||||
}
|
||||
err = rc.Close()
|
||||
if err != nil {
|
||||
t.Fatal("closing:", err)
|
||||
}
|
||||
if !bytes.Equal(b, data) {
|
||||
t.Errorf("File contents %q, want %q", b, data)
|
||||
}
|
||||
}
|
57
libgo/go/archive/zip/zip_test.go
Normal file
57
libgo/go/archive/zip/zip_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Tests that involve both reading and writing.
|
||||
|
||||
package zip
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type stringReaderAt string
|
||||
|
||||
func (s stringReaderAt) ReadAt(p []byte, off int64) (n int, err os.Error) {
|
||||
if off >= int64(len(s)) {
|
||||
return 0, os.EOF
|
||||
}
|
||||
n = copy(p, s[off:])
|
||||
return
|
||||
}
|
||||
|
||||
func TestOver65kFiles(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Logf("slow test; skipping")
|
||||
return
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
w := NewWriter(buf)
|
||||
const nFiles = (1 << 16) + 42
|
||||
for i := 0; i < nFiles; i++ {
|
||||
_, err := w.Create(fmt.Sprintf("%d.dat", i))
|
||||
if err != nil {
|
||||
t.Fatalf("creating file %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("Writer.Close: %v", err)
|
||||
}
|
||||
rat := stringReaderAt(buf.String())
|
||||
zr, err := NewReader(rat, int64(len(rat)))
|
||||
if err != nil {
|
||||
t.Fatalf("NewReader: %v", err)
|
||||
}
|
||||
if got := len(zr.File); got != nFiles {
|
||||
t.Fatalf("File contains %d files, want %d", got, nFiles)
|
||||
}
|
||||
for i := 0; i < nFiles; i++ {
|
||||
want := fmt.Sprintf("%d.dat", i)
|
||||
if zr.File[i].Name != want {
|
||||
t.Fatalf("File(%d) = %q, want %q", i, zr.File[i].Name, want)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ package asn1
|
|||
// everything by any means.
|
||||
|
||||
import (
|
||||
"big"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -88,6 +89,27 @@ func parseInt(bytes []byte) (int, os.Error) {
|
|||
return int(ret64), nil
|
||||
}
|
||||
|
||||
var bigOne = big.NewInt(1)
|
||||
|
||||
// parseBigInt treats the given bytes as a big-endian, signed integer and returns
|
||||
// the result.
|
||||
func parseBigInt(bytes []byte) *big.Int {
|
||||
ret := new(big.Int)
|
||||
if len(bytes) > 0 && bytes[0]&0x80 == 0x80 {
|
||||
// This is a negative number.
|
||||
notBytes := make([]byte, len(bytes))
|
||||
for i := range notBytes {
|
||||
notBytes[i] = ^bytes[i]
|
||||
}
|
||||
ret.SetBytes(notBytes)
|
||||
ret.Add(ret, bigOne)
|
||||
ret.Neg(ret)
|
||||
return ret
|
||||
}
|
||||
ret.SetBytes(bytes)
|
||||
return ret
|
||||
}
|
||||
|
||||
// BIT STRING
|
||||
|
||||
// BitString is the structure to use when you want an ASN.1 BIT STRING type. A
|
||||
|
@ -127,7 +149,7 @@ func (b BitString) RightAlign() []byte {
|
|||
return a
|
||||
}
|
||||
|
||||
// parseBitString parses an ASN.1 bit string from the given byte array and returns it.
|
||||
// parseBitString parses an ASN.1 bit string from the given byte slice and returns it.
|
||||
func parseBitString(bytes []byte) (ret BitString, err os.Error) {
|
||||
if len(bytes) == 0 {
|
||||
err = SyntaxError{"zero length BIT STRING"}
|
||||
|
@ -164,9 +186,9 @@ func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// parseObjectIdentifier parses an OBJECT IDENTIFER from the given bytes and
|
||||
// returns it. An object identifer is a sequence of variable length integers
|
||||
// that are assigned in a hierarachy.
|
||||
// parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and
|
||||
// returns it. An object identifier is a sequence of variable length integers
|
||||
// that are assigned in a hierarchy.
|
||||
func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
|
||||
if len(bytes) == 0 {
|
||||
err = SyntaxError{"zero length OBJECT IDENTIFIER"}
|
||||
|
@ -198,14 +220,13 @@ func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
|
|||
// An Enumerated is represented as a plain int.
|
||||
type Enumerated int
|
||||
|
||||
|
||||
// FLAG
|
||||
|
||||
// A Flag accepts any data and is set to true if present.
|
||||
type Flag bool
|
||||
|
||||
// parseBase128Int parses a base-128 encoded int from the given offset in the
|
||||
// given byte array. It returns the value and the new offset.
|
||||
// given byte slice. It returns the value and the new offset.
|
||||
func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) {
|
||||
offset = initOffset
|
||||
for shifted := 0; offset < len(bytes); shifted++ {
|
||||
|
@ -237,7 +258,7 @@ func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
// parseGeneralizedTime parses the GeneralizedTime from the given byte array
|
||||
// parseGeneralizedTime parses the GeneralizedTime from the given byte slice
|
||||
// and returns the resulting time.
|
||||
func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) {
|
||||
return time.Parse("20060102150405Z0700", string(bytes))
|
||||
|
@ -269,7 +290,7 @@ func isPrintable(b byte) bool {
|
|||
b == ':' ||
|
||||
b == '=' ||
|
||||
b == '?' ||
|
||||
// This is techincally not allowed in a PrintableString.
|
||||
// This is technically not allowed in a PrintableString.
|
||||
// However, x509 certificates with wildcard strings don't
|
||||
// always use the correct string type so we permit it.
|
||||
b == '*'
|
||||
|
@ -278,7 +299,7 @@ func isPrintable(b byte) bool {
|
|||
// IA5String
|
||||
|
||||
// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given
|
||||
// byte array and returns it.
|
||||
// byte slice and returns it.
|
||||
func parseIA5String(bytes []byte) (ret string, err os.Error) {
|
||||
for _, b := range bytes {
|
||||
if b >= 0x80 {
|
||||
|
@ -293,11 +314,19 @@ func parseIA5String(bytes []byte) (ret string, err os.Error) {
|
|||
// T61String
|
||||
|
||||
// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given
|
||||
// byte array and returns it.
|
||||
// byte slice and returns it.
|
||||
func parseT61String(bytes []byte) (ret string, err os.Error) {
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
// UTF8String
|
||||
|
||||
// parseUTF8String parses a ASN.1 UTF8String (raw UTF-8) from the given byte
|
||||
// array and returns it.
|
||||
func parseUTF8String(bytes []byte) (ret string, err os.Error) {
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
// A RawValue represents an undecoded ASN.1 object.
|
||||
type RawValue struct {
|
||||
Class, Tag int
|
||||
|
@ -314,7 +343,7 @@ type RawContent []byte
|
|||
// Tagging
|
||||
|
||||
// parseTagAndLength parses an ASN.1 tag and length pair from the given offset
|
||||
// into a byte array. It returns the parsed data and the new offset. SET and
|
||||
// into a byte slice. It returns the parsed data and the new offset. SET and
|
||||
// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
|
||||
// don't distinguish between ordered and unordered objects in this code.
|
||||
func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) {
|
||||
|
@ -371,7 +400,7 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i
|
|||
}
|
||||
|
||||
// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
|
||||
// a number of ASN.1 values from the given byte array and returns them as a
|
||||
// a number of ASN.1 values from the given byte slice and returns them as a
|
||||
// slice of Go values of the given type.
|
||||
func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) {
|
||||
expectedTag, compoundType, ok := getUniversalType(elemType)
|
||||
|
@ -425,6 +454,7 @@ var (
|
|||
timeType = reflect.TypeOf(&time.Time{})
|
||||
rawValueType = reflect.TypeOf(RawValue{})
|
||||
rawContentsType = reflect.TypeOf(RawContent(nil))
|
||||
bigIntType = reflect.TypeOf(new(big.Int))
|
||||
)
|
||||
|
||||
// invalidLength returns true iff offset + length > sliceLength, or if the
|
||||
|
@ -433,7 +463,7 @@ func invalidLength(offset, length, sliceLength int) bool {
|
|||
return offset+length < offset || offset+length > sliceLength
|
||||
}
|
||||
|
||||
// parseField is the main parsing function. Given a byte array and an offset
|
||||
// parseField is the main parsing function. Given a byte slice and an offset
|
||||
// into the array, it will try to parse a suitable ASN.1 value out and store it
|
||||
// in the given Value.
|
||||
func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) {
|
||||
|
@ -550,16 +580,15 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
|
|||
}
|
||||
}
|
||||
|
||||
// Special case for strings: PrintableString and IA5String both map to
|
||||
// the Go type string. getUniversalType returns the tag for
|
||||
// PrintableString when it sees a string so, if we see an IA5String on
|
||||
// the wire, we change the universal type to match.
|
||||
if universalTag == tagPrintableString && t.tag == tagIA5String {
|
||||
universalTag = tagIA5String
|
||||
}
|
||||
// Likewise for GeneralString
|
||||
if universalTag == tagPrintableString && t.tag == tagGeneralString {
|
||||
universalTag = tagGeneralString
|
||||
// Special case for strings: all the ASN.1 string types map to the Go
|
||||
// type string. getUniversalType returns the tag for PrintableString
|
||||
// when it sees a string, so if we see a different string type on the
|
||||
// wire, we change the universal type to match.
|
||||
if universalTag == tagPrintableString {
|
||||
switch t.tag {
|
||||
case tagIA5String, tagGeneralString, tagT61String, tagUTF8String:
|
||||
universalTag = t.tag
|
||||
}
|
||||
}
|
||||
|
||||
// Special case for time: UTCTime and GeneralizedTime both map to the
|
||||
|
@ -639,6 +668,10 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
|
|||
case flagType:
|
||||
v.SetBool(true)
|
||||
return
|
||||
case bigIntType:
|
||||
parsedInt := parseBigInt(innerBytes)
|
||||
v.Set(reflect.ValueOf(parsedInt))
|
||||
return
|
||||
}
|
||||
switch val := v; val.Kind() {
|
||||
case reflect.Bool:
|
||||
|
@ -648,23 +681,21 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
|
|||
}
|
||||
err = err1
|
||||
return
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
switch val.Type().Kind() {
|
||||
case reflect.Int:
|
||||
parsedInt, err1 := parseInt(innerBytes)
|
||||
if err1 == nil {
|
||||
val.SetInt(int64(parsedInt))
|
||||
}
|
||||
err = err1
|
||||
return
|
||||
case reflect.Int64:
|
||||
parsedInt, err1 := parseInt64(innerBytes)
|
||||
if err1 == nil {
|
||||
val.SetInt(parsedInt)
|
||||
}
|
||||
err = err1
|
||||
return
|
||||
case reflect.Int, reflect.Int32:
|
||||
parsedInt, err1 := parseInt(innerBytes)
|
||||
if err1 == nil {
|
||||
val.SetInt(int64(parsedInt))
|
||||
}
|
||||
err = err1
|
||||
return
|
||||
case reflect.Int64:
|
||||
parsedInt, err1 := parseInt64(innerBytes)
|
||||
if err1 == nil {
|
||||
val.SetInt(parsedInt)
|
||||
}
|
||||
err = err1
|
||||
return
|
||||
// TODO(dfc) Add support for the remaining integer types
|
||||
case reflect.Struct:
|
||||
structType := fieldType
|
||||
|
||||
|
@ -680,7 +711,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
|
|||
if i == 0 && field.Type == rawContentsType {
|
||||
continue
|
||||
}
|
||||
innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag))
|
||||
innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag.Get("asn1")))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -711,6 +742,8 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
|
|||
v, err = parseIA5String(innerBytes)
|
||||
case tagT61String:
|
||||
v, err = parseT61String(innerBytes)
|
||||
case tagUTF8String:
|
||||
v, err = parseUTF8String(innerBytes)
|
||||
case tagGeneralString:
|
||||
// GeneralString is specified in ISO-2022/ECMA-35,
|
||||
// A brief review suggests that it includes structures
|
||||
|
@ -725,7 +758,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
|
|||
}
|
||||
return
|
||||
}
|
||||
err = StructuralError{"unknown Go type"}
|
||||
err = StructuralError{"unsupported: " + v.Type().String()}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -752,7 +785,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
|
|||
// Because Unmarshal uses the reflect package, the structs
|
||||
// being written to must use upper case field names.
|
||||
//
|
||||
// An ASN.1 INTEGER can be written to an int or int64.
|
||||
// An ASN.1 INTEGER can be written to an int, int32 or int64.
|
||||
// If the encoded value does not fit in the Go type,
|
||||
// Unmarshal returns a parse error.
|
||||
//
|
||||
|
|
|
@ -42,6 +42,64 @@ func TestParseInt64(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type int32Test struct {
|
||||
in []byte
|
||||
ok bool
|
||||
out int32
|
||||
}
|
||||
|
||||
var int32TestData = []int32Test{
|
||||
{[]byte{0x00}, true, 0},
|
||||
{[]byte{0x7f}, true, 127},
|
||||
{[]byte{0x00, 0x80}, true, 128},
|
||||
{[]byte{0x01, 0x00}, true, 256},
|
||||
{[]byte{0x80}, true, -128},
|
||||
{[]byte{0xff, 0x7f}, true, -129},
|
||||
{[]byte{0xff, 0xff, 0xff, 0xff}, true, -1},
|
||||
{[]byte{0xff}, true, -1},
|
||||
{[]byte{0x80, 0x00, 0x00, 0x00}, true, -2147483648},
|
||||
{[]byte{0x80, 0x00, 0x00, 0x00, 0x00}, false, 0},
|
||||
}
|
||||
|
||||
func TestParseInt32(t *testing.T) {
|
||||
for i, test := range int32TestData {
|
||||
ret, err := parseInt(test.in)
|
||||
if (err == nil) != test.ok {
|
||||
t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok)
|
||||
}
|
||||
if test.ok && int32(ret) != test.out {
|
||||
t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var bigIntTests = []struct {
|
||||
in []byte
|
||||
base10 string
|
||||
}{
|
||||
{[]byte{0xff}, "-1"},
|
||||
{[]byte{0x00}, "0"},
|
||||
{[]byte{0x01}, "1"},
|
||||
{[]byte{0x00, 0xff}, "255"},
|
||||
{[]byte{0xff, 0x00}, "-256"},
|
||||
{[]byte{0x01, 0x00}, "256"},
|
||||
}
|
||||
|
||||
func TestParseBigInt(t *testing.T) {
|
||||
for i, test := range bigIntTests {
|
||||
ret := parseBigInt(test.in)
|
||||
if ret.String() != test.base10 {
|
||||
t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10)
|
||||
}
|
||||
fw := newForkableWriter()
|
||||
marshalBigInt(fw, ret)
|
||||
result := fw.Bytes()
|
||||
if !bytes.Equal(result, test.in) {
|
||||
t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type bitStringTest struct {
|
||||
in []byte
|
||||
ok bool
|
||||
|
@ -148,10 +206,10 @@ type timeTest struct {
|
|||
}
|
||||
|
||||
var utcTestData = []timeTest{
|
||||
{"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, -7 * 60 * 60, ""}},
|
||||
{"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 7*60*60 + 30*60, ""}},
|
||||
{"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, "UTC"}},
|
||||
{"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, "UTC"}},
|
||||
{"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 0, -7 * 60 * 60, ""}},
|
||||
{"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 0, 7*60*60 + 30*60, ""}},
|
||||
{"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, 0, "UTC"}},
|
||||
{"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, 0, "UTC"}},
|
||||
{"a10506234540Z", false, nil},
|
||||
{"91a506234540Z", false, nil},
|
||||
{"9105a6234540Z", false, nil},
|
||||
|
@ -177,10 +235,10 @@ func TestUTCTime(t *testing.T) {
|
|||
}
|
||||
|
||||
var generalizedTimeTestData = []timeTest{
|
||||
{"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, "UTC"}},
|
||||
{"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, 0, "UTC"}},
|
||||
{"20100102030405", false, nil},
|
||||
{"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 6*60*60 + 7*60, ""}},
|
||||
{"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, -6*60*60 - 7*60, ""}},
|
||||
{"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, 6*60*60 + 7*60, ""}},
|
||||
{"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, -6*60*60 - 7*60, ""}},
|
||||
}
|
||||
|
||||
func TestGeneralizedTime(t *testing.T) {
|
||||
|
@ -272,11 +330,11 @@ type TestObjectIdentifierStruct struct {
|
|||
}
|
||||
|
||||
type TestContextSpecificTags struct {
|
||||
A int "tag:1"
|
||||
A int `asn1:"tag:1"`
|
||||
}
|
||||
|
||||
type TestContextSpecificTags2 struct {
|
||||
A int "explicit,tag:1"
|
||||
A int `asn1:"explicit,tag:1"`
|
||||
B int
|
||||
}
|
||||
|
||||
|
@ -326,7 +384,7 @@ type Certificate struct {
|
|||
}
|
||||
|
||||
type TBSCertificate struct {
|
||||
Version int "optional,explicit,default:0,tag:0"
|
||||
Version int `asn1:"optional,explicit,default:0,tag:0"`
|
||||
SerialNumber RawValue
|
||||
SignatureAlgorithm AlgorithmIdentifier
|
||||
Issuer RDNSequence
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
// ASN.1 objects have metadata preceeding them:
|
||||
// ASN.1 objects have metadata preceding them:
|
||||
// the tag: the type of the object
|
||||
// a flag denoting if this object is compound or not
|
||||
// the class type: the namespace of the tag
|
||||
|
@ -25,6 +25,7 @@ const (
|
|||
tagOctetString = 4
|
||||
tagOID = 6
|
||||
tagEnum = 10
|
||||
tagUTF8String = 12
|
||||
tagSequence = 16
|
||||
tagSet = 17
|
||||
tagPrintableString = 19
|
||||
|
@ -83,7 +84,7 @@ type fieldParameters struct {
|
|||
// parseFieldParameters will parse it into a fieldParameters structure,
|
||||
// ignoring unknown parts of the string.
|
||||
func parseFieldParameters(str string) (ret fieldParameters) {
|
||||
for _, part := range strings.Split(str, ",", -1) {
|
||||
for _, part := range strings.Split(str, ",") {
|
||||
switch {
|
||||
case part == "optional":
|
||||
ret.optional = true
|
||||
|
@ -132,6 +133,8 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) {
|
|||
return tagUTCTime, false, true
|
||||
case enumeratedType:
|
||||
return tagEnum, false, true
|
||||
case bigIntType:
|
||||
return tagInteger, false, true
|
||||
}
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package asn1
|
||||
|
||||
import (
|
||||
"big"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -125,6 +126,43 @@ func int64Length(i int64) (numBytes int) {
|
|||
return
|
||||
}
|
||||
|
||||
func marshalBigInt(out *forkableWriter, n *big.Int) (err os.Error) {
|
||||
if n.Sign() < 0 {
|
||||
// A negative number has to be converted to two's-complement
|
||||
// form. So we'll subtract 1 and invert. If the
|
||||
// most-significant-bit isn't set then we'll need to pad the
|
||||
// beginning with 0xff in order to keep the number negative.
|
||||
nMinus1 := new(big.Int).Neg(n)
|
||||
nMinus1.Sub(nMinus1, bigOne)
|
||||
bytes := nMinus1.Bytes()
|
||||
for i := range bytes {
|
||||
bytes[i] ^= 0xff
|
||||
}
|
||||
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
|
||||
err = out.WriteByte(0xff)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
_, err = out.Write(bytes)
|
||||
} else if n.Sign() == 0 {
|
||||
// Zero is written as a single 0 zero rather than no bytes.
|
||||
err = out.WriteByte(0x00)
|
||||
} else {
|
||||
bytes := n.Bytes()
|
||||
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
|
||||
// We'll have to pad this with 0x00 in order to stop it
|
||||
// looking like a negative number.
|
||||
err = out.WriteByte(0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
_, err = out.Write(bytes)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func marshalLength(out *forkableWriter, i int) (err os.Error) {
|
||||
n := lengthLength(i)
|
||||
|
||||
|
@ -334,6 +372,8 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
|
|||
return marshalBitString(out, value.Interface().(BitString))
|
||||
case objectIdentifierType:
|
||||
return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
|
||||
case bigIntType:
|
||||
return marshalBigInt(out, value.Interface().(*big.Int))
|
||||
}
|
||||
|
||||
switch v := value; v.Kind() {
|
||||
|
@ -351,7 +391,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
|
|||
startingField := 0
|
||||
|
||||
// If the first element of the structure is a non-empty
|
||||
// RawContents, then we don't bother serialising the rest.
|
||||
// RawContents, then we don't bother serializing the rest.
|
||||
if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
|
||||
s := v.Field(0)
|
||||
if s.Len() > 0 {
|
||||
|
@ -361,7 +401,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
|
|||
}
|
||||
/* The RawContents will contain the tag and
|
||||
* length fields but we'll also be writing
|
||||
* those outselves, so we strip them out of
|
||||
* those ourselves, so we strip them out of
|
||||
* bytes */
|
||||
_, err = out.Write(stripTagAndLength(bytes))
|
||||
return
|
||||
|
@ -373,7 +413,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
|
|||
for i := startingField; i < t.NumField(); i++ {
|
||||
var pre *forkableWriter
|
||||
pre, out = out.fork()
|
||||
err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag))
|
||||
err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -418,6 +458,10 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
|
|||
return marshalField(out, v.Elem(), params)
|
||||
}
|
||||
|
||||
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
|
||||
return
|
||||
}
|
||||
|
||||
if v.Type() == rawValueType {
|
||||
rv := v.Interface().(RawValue)
|
||||
err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
|
||||
|
@ -428,10 +472,6 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
|
|||
return
|
||||
}
|
||||
|
||||
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
|
||||
return
|
||||
}
|
||||
|
||||
tag, isCompound, ok := getUniversalType(v.Type())
|
||||
if !ok {
|
||||
err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
|
||||
|
|
|
@ -30,19 +30,23 @@ type rawContentsStruct struct {
|
|||
}
|
||||
|
||||
type implicitTagTest struct {
|
||||
A int "implicit,tag:5"
|
||||
A int `asn1:"implicit,tag:5"`
|
||||
}
|
||||
|
||||
type explicitTagTest struct {
|
||||
A int "explicit,tag:5"
|
||||
A int `asn1:"explicit,tag:5"`
|
||||
}
|
||||
|
||||
type ia5StringTest struct {
|
||||
A string "ia5"
|
||||
A string `asn1:"ia5"`
|
||||
}
|
||||
|
||||
type printableStringTest struct {
|
||||
A string "printable"
|
||||
A string `asn1:"printable"`
|
||||
}
|
||||
|
||||
type optionalRawValueTest struct {
|
||||
A RawValue `asn1:"optional"`
|
||||
}
|
||||
|
||||
type testSET []int
|
||||
|
@ -102,6 +106,7 @@ var marshalTests = []marshalTest{
|
|||
"7878787878787878787878787878787878787878787878787878787878787878",
|
||||
},
|
||||
{ia5StringTest{"test"}, "3006160474657374"},
|
||||
{optionalRawValueTest{}, "3000"},
|
||||
{printableStringTest{"test"}, "3006130474657374"},
|
||||
{printableStringTest{"test*"}, "30071305746573742a"},
|
||||
{rawContentsStruct{nil, 64}, "3003020140"},
|
||||
|
|
|
@ -27,7 +27,6 @@ const (
|
|||
_M2 = _B2 - 1 // half digit mask
|
||||
)
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Elementary operations on words
|
||||
//
|
||||
|
@ -43,7 +42,6 @@ func addWW_g(x, y, c Word) (z1, z0 Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// z1<<_W + z0 = x-y-c, with c == 0 or 1
|
||||
func subWW_g(x, y, c Word) (z1, z0 Word) {
|
||||
yc := y + c
|
||||
|
@ -54,7 +52,6 @@ func subWW_g(x, y, c Word) (z1, z0 Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// z1<<_W + z0 = x*y
|
||||
func mulWW(x, y Word) (z1, z0 Word) { return mulWW_g(x, y) }
|
||||
// Adapted from Warren, Hacker's Delight, p. 132.
|
||||
|
@ -73,7 +70,6 @@ func mulWW_g(x, y Word) (z1, z0 Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// z1<<_W + z0 = x*y + c
|
||||
func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
|
||||
z1, zz0 := mulWW(x, y)
|
||||
|
@ -83,7 +79,6 @@ func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// Length of x in bits.
|
||||
func bitLen(x Word) (n int) {
|
||||
for ; x >= 0x100; x >>= 8 {
|
||||
|
@ -95,7 +90,6 @@ func bitLen(x Word) (n int) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// log2 computes the integer binary logarithm of x.
|
||||
// The result is the integer n for which 2^n <= x < 2^(n+1).
|
||||
// If x == 0, the result is -1.
|
||||
|
@ -103,13 +97,11 @@ func log2(x Word) int {
|
|||
return bitLen(x) - 1
|
||||
}
|
||||
|
||||
|
||||
// Number of leading zeros in x.
|
||||
func leadingZeros(x Word) uint {
|
||||
return uint(_W - bitLen(x))
|
||||
}
|
||||
|
||||
|
||||
// q = (u1<<_W + u0 - r)/y
|
||||
func divWW(x1, x0, y Word) (q, r Word) { return divWW_g(x1, x0, y) }
|
||||
// Adapted from Warren, Hacker's Delight, p. 152.
|
||||
|
@ -155,7 +147,6 @@ again2:
|
|||
return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s
|
||||
}
|
||||
|
||||
|
||||
func addVV(z, x, y []Word) (c Word) { return addVV_g(z, x, y) }
|
||||
func addVV_g(z, x, y []Word) (c Word) {
|
||||
for i := range z {
|
||||
|
@ -164,7 +155,6 @@ func addVV_g(z, x, y []Word) (c Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func subVV(z, x, y []Word) (c Word) { return subVV_g(z, x, y) }
|
||||
func subVV_g(z, x, y []Word) (c Word) {
|
||||
for i := range z {
|
||||
|
@ -173,7 +163,6 @@ func subVV_g(z, x, y []Word) (c Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func addVW(z, x []Word, y Word) (c Word) { return addVW_g(z, x, y) }
|
||||
func addVW_g(z, x []Word, y Word) (c Word) {
|
||||
c = y
|
||||
|
@ -183,7 +172,6 @@ func addVW_g(z, x []Word, y Word) (c Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func subVW(z, x []Word, y Word) (c Word) { return subVW_g(z, x, y) }
|
||||
func subVW_g(z, x []Word, y Word) (c Word) {
|
||||
c = y
|
||||
|
@ -193,9 +181,8 @@ func subVW_g(z, x []Word, y Word) (c Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func shlVW(z, x []Word, s Word) (c Word) { return shlVW_g(z, x, s) }
|
||||
func shlVW_g(z, x []Word, s Word) (c Word) {
|
||||
func shlVU(z, x []Word, s uint) (c Word) { return shlVU_g(z, x, s) }
|
||||
func shlVU_g(z, x []Word, s uint) (c Word) {
|
||||
if n := len(z); n > 0 {
|
||||
ŝ := _W - s
|
||||
w1 := x[n-1]
|
||||
|
@ -210,9 +197,8 @@ func shlVW_g(z, x []Word, s Word) (c Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func shrVW(z, x []Word, s Word) (c Word) { return shrVW_g(z, x, s) }
|
||||
func shrVW_g(z, x []Word, s Word) (c Word) {
|
||||
func shrVU(z, x []Word, s uint) (c Word) { return shrVU_g(z, x, s) }
|
||||
func shrVU_g(z, x []Word, s uint) (c Word) {
|
||||
if n := len(z); n > 0 {
|
||||
ŝ := _W - s
|
||||
w1 := x[0]
|
||||
|
@ -227,7 +213,6 @@ func shrVW_g(z, x []Word, s Word) (c Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func mulAddVWW(z, x []Word, y, r Word) (c Word) { return mulAddVWW_g(z, x, y, r) }
|
||||
func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
|
||||
c = r
|
||||
|
@ -237,7 +222,6 @@ func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func addMulVVW(z, x []Word, y Word) (c Word) { return addMulVVW_g(z, x, y) }
|
||||
func addMulVVW_g(z, x []Word, y Word) (c Word) {
|
||||
for i := range z {
|
||||
|
@ -248,7 +232,6 @@ func addMulVVW_g(z, x []Word, y Word) (c Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { return divWVW_g(z, xn, x, y) }
|
||||
func divWVW_g(z []Word, xn Word, x []Word, y Word) (r Word) {
|
||||
r = xn
|
||||
|
|
|
@ -11,8 +11,8 @@ func addVV(z, x, y []Word) (c Word)
|
|||
func subVV(z, x, y []Word) (c Word)
|
||||
func addVW(z, x []Word, y Word) (c Word)
|
||||
func subVW(z, x []Word, y Word) (c Word)
|
||||
func shlVW(z, x []Word, s Word) (c Word)
|
||||
func shrVW(z, x []Word, s Word) (c Word)
|
||||
func shlVU(z, x []Word, s uint) (c Word)
|
||||
func shrVU(z, x []Word, s uint) (c Word)
|
||||
func mulAddVWW(z, x []Word, y, r Word) (c Word)
|
||||
func addMulVVW(z, x []Word, y Word) (c Word)
|
||||
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word)
|
||||
|
|
|
@ -6,7 +6,6 @@ package big
|
|||
|
||||
import "testing"
|
||||
|
||||
|
||||
type funWW func(x, y, c Word) (z1, z0 Word)
|
||||
type argWW struct {
|
||||
x, y, c, z1, z0 Word
|
||||
|
@ -26,7 +25,6 @@ var sumWW = []argWW{
|
|||
{_M, _M, 1, 1, _M},
|
||||
}
|
||||
|
||||
|
||||
func testFunWW(t *testing.T, msg string, f funWW, a argWW) {
|
||||
z1, z0 := f(a.x, a.y, a.c)
|
||||
if z1 != a.z1 || z0 != a.z0 {
|
||||
|
@ -34,7 +32,6 @@ func testFunWW(t *testing.T, msg string, f funWW, a argWW) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestFunWW(t *testing.T) {
|
||||
for _, a := range sumWW {
|
||||
arg := a
|
||||
|
@ -51,7 +48,6 @@ func TestFunWW(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type funVV func(z, x, y []Word) (c Word)
|
||||
type argVV struct {
|
||||
z, x, y nat
|
||||
|
@ -70,7 +66,6 @@ var sumVV = []argVV{
|
|||
{nat{0, 0, 0, 0}, nat{_M, 0, _M, 0}, nat{1, _M, 0, _M}, 1},
|
||||
}
|
||||
|
||||
|
||||
func testFunVV(t *testing.T, msg string, f funVV, a argVV) {
|
||||
z := make(nat, len(a.z))
|
||||
c := f(z, a.x, a.y)
|
||||
|
@ -85,7 +80,6 @@ func testFunVV(t *testing.T, msg string, f funVV, a argVV) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestFunVV(t *testing.T) {
|
||||
for _, a := range sumVV {
|
||||
arg := a
|
||||
|
@ -106,7 +100,6 @@ func TestFunVV(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type funVW func(z, x []Word, y Word) (c Word)
|
||||
type argVW struct {
|
||||
z, x nat
|
||||
|
@ -169,7 +162,6 @@ var rshVW = []argVW{
|
|||
{nat{_M, _M, _M >> 20}, nat{_M, _M, _M}, 20, _M << (_W - 20) & _M},
|
||||
}
|
||||
|
||||
|
||||
func testFunVW(t *testing.T, msg string, f funVW, a argVW) {
|
||||
z := make(nat, len(a.z))
|
||||
c := f(z, a.x, a.y)
|
||||
|
@ -184,6 +176,11 @@ func testFunVW(t *testing.T, msg string, f funVW, a argVW) {
|
|||
}
|
||||
}
|
||||
|
||||
func makeFunVW(f func(z, x []Word, s uint) (c Word)) funVW {
|
||||
return func(z, x []Word, s Word) (c Word) {
|
||||
return f(z, x, uint(s))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunVW(t *testing.T) {
|
||||
for _, a := range sumVW {
|
||||
|
@ -196,20 +193,23 @@ func TestFunVW(t *testing.T) {
|
|||
testFunVW(t, "subVW", subVW, arg)
|
||||
}
|
||||
|
||||
shlVW_g := makeFunVW(shlVU_g)
|
||||
shlVW := makeFunVW(shlVU)
|
||||
for _, a := range lshVW {
|
||||
arg := a
|
||||
testFunVW(t, "shlVW_g", shlVW_g, arg)
|
||||
testFunVW(t, "shlVW", shlVW, arg)
|
||||
testFunVW(t, "shlVU_g", shlVW_g, arg)
|
||||
testFunVW(t, "shlVU", shlVW, arg)
|
||||
}
|
||||
|
||||
shrVW_g := makeFunVW(shrVU_g)
|
||||
shrVW := makeFunVW(shrVU)
|
||||
for _, a := range rshVW {
|
||||
arg := a
|
||||
testFunVW(t, "shrVW_g", shrVW_g, arg)
|
||||
testFunVW(t, "shrVW", shrVW, arg)
|
||||
testFunVW(t, "shrVU_g", shrVW_g, arg)
|
||||
testFunVW(t, "shrVU", shrVW, arg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
type funVWW func(z, x []Word, y, r Word) (c Word)
|
||||
type argVWW struct {
|
||||
z, x nat
|
||||
|
@ -243,7 +243,6 @@ var prodVWW = []argVWW{
|
|||
{nat{_M<<7&_M + 1<<6, _M, _M, _M}, nat{_M, _M, _M, _M}, 1 << 7, 1 << 6, _M >> (_W - 7)},
|
||||
}
|
||||
|
||||
|
||||
func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) {
|
||||
z := make(nat, len(a.z))
|
||||
c := f(z, a.x, a.y, a.r)
|
||||
|
@ -258,7 +257,6 @@ func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// TODO(gri) mulAddVWW and divWVW are symmetric operations but
|
||||
// their signature is not symmetric. Try to unify.
|
||||
|
||||
|
@ -285,7 +283,6 @@ func testFunWVW(t *testing.T, msg string, f funWVW, a argWVW) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestFunVWW(t *testing.T) {
|
||||
for _, a := range prodVWW {
|
||||
arg := a
|
||||
|
@ -300,7 +297,6 @@ func TestFunVWW(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var mulWWTests = []struct {
|
||||
x, y Word
|
||||
q, r Word
|
||||
|
@ -309,7 +305,6 @@ var mulWWTests = []struct {
|
|||
// 32 bit only: {0xc47dfa8c, 50911, 0x98a4, 0x998587f4},
|
||||
}
|
||||
|
||||
|
||||
func TestMulWW(t *testing.T) {
|
||||
for i, test := range mulWWTests {
|
||||
q, r := mulWW_g(test.x, test.y)
|
||||
|
@ -319,7 +314,6 @@ func TestMulWW(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var mulAddWWWTests = []struct {
|
||||
x, y, c Word
|
||||
q, r Word
|
||||
|
@ -331,7 +325,6 @@ var mulAddWWWTests = []struct {
|
|||
{_M, _M, _M, _M, 0},
|
||||
}
|
||||
|
||||
|
||||
func TestMulAddWWW(t *testing.T) {
|
||||
for i, test := range mulAddWWWTests {
|
||||
q, r := mulAddWWW_g(test.x, test.y, test.c)
|
||||
|
|
|
@ -19,10 +19,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
|
||||
var calibrate = flag.Bool("calibrate", false, "run calibration test")
|
||||
|
||||
|
||||
// measure returns the time to run f
|
||||
func measure(f func()) int64 {
|
||||
const N = 100
|
||||
|
@ -34,7 +32,6 @@ func measure(f func()) int64 {
|
|||
return (stop - start) / N
|
||||
}
|
||||
|
||||
|
||||
func computeThresholds() {
|
||||
fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
|
||||
fmt.Printf("(run repeatedly for good results)\n")
|
||||
|
@ -84,7 +81,6 @@ func computeThresholds() {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestCalibrate(t *testing.T) {
|
||||
if *calibrate {
|
||||
computeThresholds()
|
||||
|
|
|
@ -13,13 +13,11 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
|
||||
type matrix struct {
|
||||
n, m int
|
||||
a []*Rat
|
||||
}
|
||||
|
||||
|
||||
func (a *matrix) at(i, j int) *Rat {
|
||||
if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
|
||||
panic("index out of range")
|
||||
|
@ -27,7 +25,6 @@ func (a *matrix) at(i, j int) *Rat {
|
|||
return a.a[i*a.m+j]
|
||||
}
|
||||
|
||||
|
||||
func (a *matrix) set(i, j int, x *Rat) {
|
||||
if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
|
||||
panic("index out of range")
|
||||
|
@ -35,7 +32,6 @@ func (a *matrix) set(i, j int, x *Rat) {
|
|||
a.a[i*a.m+j] = x
|
||||
}
|
||||
|
||||
|
||||
func newMatrix(n, m int) *matrix {
|
||||
if !(0 <= n && 0 <= m) {
|
||||
panic("illegal matrix")
|
||||
|
@ -47,7 +43,6 @@ func newMatrix(n, m int) *matrix {
|
|||
return a
|
||||
}
|
||||
|
||||
|
||||
func newUnit(n int) *matrix {
|
||||
a := newMatrix(n, n)
|
||||
for i := 0; i < n; i++ {
|
||||
|
@ -62,7 +57,6 @@ func newUnit(n int) *matrix {
|
|||
return a
|
||||
}
|
||||
|
||||
|
||||
func newHilbert(n int) *matrix {
|
||||
a := newMatrix(n, n)
|
||||
for i := 0; i < n; i++ {
|
||||
|
@ -73,7 +67,6 @@ func newHilbert(n int) *matrix {
|
|||
return a
|
||||
}
|
||||
|
||||
|
||||
func newInverseHilbert(n int) *matrix {
|
||||
a := newMatrix(n, n)
|
||||
for i := 0; i < n; i++ {
|
||||
|
@ -98,7 +91,6 @@ func newInverseHilbert(n int) *matrix {
|
|||
return a
|
||||
}
|
||||
|
||||
|
||||
func (a *matrix) mul(b *matrix) *matrix {
|
||||
if a.m != b.n {
|
||||
panic("illegal matrix multiply")
|
||||
|
@ -116,7 +108,6 @@ func (a *matrix) mul(b *matrix) *matrix {
|
|||
return c
|
||||
}
|
||||
|
||||
|
||||
func (a *matrix) eql(b *matrix) bool {
|
||||
if a.n != b.n || a.m != b.m {
|
||||
return false
|
||||
|
@ -131,7 +122,6 @@ func (a *matrix) eql(b *matrix) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
|
||||
func (a *matrix) String() string {
|
||||
s := ""
|
||||
for i := 0; i < a.n; i++ {
|
||||
|
@ -143,7 +133,6 @@ func (a *matrix) String() string {
|
|||
return s
|
||||
}
|
||||
|
||||
|
||||
func doHilbert(t *testing.T, n int) {
|
||||
a := newHilbert(n)
|
||||
b := newInverseHilbert(n)
|
||||
|
@ -160,12 +149,10 @@ func doHilbert(t *testing.T, n int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestHilbert(t *testing.T) {
|
||||
doHilbert(t, 10)
|
||||
}
|
||||
|
||||
|
||||
func BenchmarkHilbert(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
doHilbert(nil, 10)
|
||||
|
|
|
@ -8,8 +8,10 @@ package big
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"rand"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// An Int represents a signed multi-precision integer.
|
||||
|
@ -19,10 +21,8 @@ type Int struct {
|
|||
abs nat // absolute value of the integer
|
||||
}
|
||||
|
||||
|
||||
var intOne = &Int{false, natOne}
|
||||
|
||||
|
||||
// Sign returns:
|
||||
//
|
||||
// -1 if x < 0
|
||||
|
@ -39,7 +39,6 @@ func (x *Int) Sign() int {
|
|||
return 1
|
||||
}
|
||||
|
||||
|
||||
// SetInt64 sets z to x and returns z.
|
||||
func (z *Int) SetInt64(x int64) *Int {
|
||||
neg := false
|
||||
|
@ -52,13 +51,11 @@ func (z *Int) SetInt64(x int64) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// NewInt allocates and returns a new Int set to x.
|
||||
func NewInt(x int64) *Int {
|
||||
return new(Int).SetInt64(x)
|
||||
}
|
||||
|
||||
|
||||
// Set sets z to x and returns z.
|
||||
func (z *Int) Set(x *Int) *Int {
|
||||
z.abs = z.abs.set(x.abs)
|
||||
|
@ -66,7 +63,6 @@ func (z *Int) Set(x *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Abs sets z to |x| (the absolute value of x) and returns z.
|
||||
func (z *Int) Abs(x *Int) *Int {
|
||||
z.abs = z.abs.set(x.abs)
|
||||
|
@ -74,7 +70,6 @@ func (z *Int) Abs(x *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Neg sets z to -x and returns z.
|
||||
func (z *Int) Neg(x *Int) *Int {
|
||||
z.abs = z.abs.set(x.abs)
|
||||
|
@ -82,7 +77,6 @@ func (z *Int) Neg(x *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Add sets z to the sum x+y and returns z.
|
||||
func (z *Int) Add(x, y *Int) *Int {
|
||||
neg := x.neg
|
||||
|
@ -104,7 +98,6 @@ func (z *Int) Add(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Sub sets z to the difference x-y and returns z.
|
||||
func (z *Int) Sub(x, y *Int) *Int {
|
||||
neg := x.neg
|
||||
|
@ -126,7 +119,6 @@ func (z *Int) Sub(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Mul sets z to the product x*y and returns z.
|
||||
func (z *Int) Mul(x, y *Int) *Int {
|
||||
// x * y == x * y
|
||||
|
@ -138,7 +130,6 @@ func (z *Int) Mul(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// MulRange sets z to the product of all integers
|
||||
// in the range [a, b] inclusively and returns z.
|
||||
// If a > b (empty range), the result is 1.
|
||||
|
@ -162,7 +153,6 @@ func (z *Int) MulRange(a, b int64) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Binomial sets z to the binomial coefficient of (n, k) and returns z.
|
||||
func (z *Int) Binomial(n, k int64) *Int {
|
||||
var a, b Int
|
||||
|
@ -171,7 +161,6 @@ func (z *Int) Binomial(n, k int64) *Int {
|
|||
return z.Quo(&a, &b)
|
||||
}
|
||||
|
||||
|
||||
// Quo sets z to the quotient x/y for y != 0 and returns z.
|
||||
// If y == 0, a division-by-zero run-time panic occurs.
|
||||
// See QuoRem for more details.
|
||||
|
@ -181,7 +170,6 @@ func (z *Int) Quo(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Rem sets z to the remainder x%y for y != 0 and returns z.
|
||||
// If y == 0, a division-by-zero run-time panic occurs.
|
||||
// See QuoRem for more details.
|
||||
|
@ -191,7 +179,6 @@ func (z *Int) Rem(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// QuoRem sets z to the quotient x/y and r to the remainder x%y
|
||||
// and returns the pair (z, r) for y != 0.
|
||||
// If y == 0, a division-by-zero run-time panic occurs.
|
||||
|
@ -209,7 +196,6 @@ func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) {
|
|||
return z, r
|
||||
}
|
||||
|
||||
|
||||
// Div sets z to the quotient x/y for y != 0 and returns z.
|
||||
// If y == 0, a division-by-zero run-time panic occurs.
|
||||
// See DivMod for more details.
|
||||
|
@ -227,7 +213,6 @@ func (z *Int) Div(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Mod sets z to the modulus x%y for y != 0 and returns z.
|
||||
// If y == 0, a division-by-zero run-time panic occurs.
|
||||
// See DivMod for more details.
|
||||
|
@ -248,7 +233,6 @@ func (z *Int) Mod(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// DivMod sets z to the quotient x div y and m to the modulus x mod y
|
||||
// and returns the pair (z, m) for y != 0.
|
||||
// If y == 0, a division-by-zero run-time panic occurs.
|
||||
|
@ -281,7 +265,6 @@ func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) {
|
|||
return z, m
|
||||
}
|
||||
|
||||
|
||||
// Cmp compares x and y and returns:
|
||||
//
|
||||
// -1 if x < y
|
||||
|
@ -307,49 +290,197 @@ func (x *Int) Cmp(y *Int) (r int) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func (x *Int) String() string {
|
||||
s := ""
|
||||
if x.neg {
|
||||
s = "-"
|
||||
switch {
|
||||
case x == nil:
|
||||
return "<nil>"
|
||||
case x.neg:
|
||||
return "-" + x.abs.decimalString()
|
||||
}
|
||||
return s + x.abs.string(10)
|
||||
return x.abs.decimalString()
|
||||
}
|
||||
|
||||
|
||||
func fmtbase(ch int) int {
|
||||
func charset(ch int) string {
|
||||
switch ch {
|
||||
case 'b':
|
||||
return 2
|
||||
return lowercaseDigits[0:2]
|
||||
case 'o':
|
||||
return 8
|
||||
case 'd':
|
||||
return 10
|
||||
return lowercaseDigits[0:8]
|
||||
case 'd', 's', 'v':
|
||||
return lowercaseDigits[0:10]
|
||||
case 'x':
|
||||
return 16
|
||||
return lowercaseDigits[0:16]
|
||||
case 'X':
|
||||
return uppercaseDigits[0:16]
|
||||
}
|
||||
return 10
|
||||
return "" // unknown format
|
||||
}
|
||||
|
||||
// write count copies of text to s
|
||||
func writeMultiple(s fmt.State, text string, count int) {
|
||||
if len(text) > 0 {
|
||||
b := []byte(text)
|
||||
for ; count > 0; count-- {
|
||||
s.Write(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Format is a support routine for fmt.Formatter. It accepts
|
||||
// the formats 'b' (binary), 'o' (octal), 'd' (decimal) and
|
||||
// 'x' (hexadecimal).
|
||||
// the formats 'b' (binary), 'o' (octal), 'd' (decimal), 'x'
|
||||
// (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
|
||||
// Also supported are the full suite of package fmt's format
|
||||
// verbs for integral types, including '+', '-', and ' '
|
||||
// for sign control, '#' for leading zero in octal and for
|
||||
// hexadecimal, a leading "0x" or "0X" for "%#x" and "%#X"
|
||||
// respectively, specification of minimum digits precision,
|
||||
// output field width, space or zero padding, and left or
|
||||
// right justification.
|
||||
//
|
||||
func (x *Int) Format(s fmt.State, ch int) {
|
||||
if x == nil {
|
||||
cs := charset(ch)
|
||||
|
||||
// special cases
|
||||
switch {
|
||||
case cs == "":
|
||||
// unknown format
|
||||
fmt.Fprintf(s, "%%!%c(big.Int=%s)", ch, x.String())
|
||||
return
|
||||
case x == nil:
|
||||
fmt.Fprint(s, "<nil>")
|
||||
return
|
||||
}
|
||||
if x.neg {
|
||||
fmt.Fprint(s, "-")
|
||||
|
||||
// determine sign character
|
||||
sign := ""
|
||||
switch {
|
||||
case x.neg:
|
||||
sign = "-"
|
||||
case s.Flag('+'): // supersedes ' ' when both specified
|
||||
sign = "+"
|
||||
case s.Flag(' '):
|
||||
sign = " "
|
||||
}
|
||||
fmt.Fprint(s, x.abs.string(fmtbase(ch)))
|
||||
|
||||
// determine prefix characters for indicating output base
|
||||
prefix := ""
|
||||
if s.Flag('#') {
|
||||
switch ch {
|
||||
case 'o': // octal
|
||||
prefix = "0"
|
||||
case 'x': // hexadecimal
|
||||
prefix = "0x"
|
||||
case 'X':
|
||||
prefix = "0X"
|
||||
}
|
||||
}
|
||||
|
||||
// determine digits with base set by len(cs) and digit characters from cs
|
||||
digits := x.abs.string(cs)
|
||||
|
||||
// number of characters for the three classes of number padding
|
||||
var left int // space characters to left of digits for right justification ("%8d")
|
||||
var zeroes int // zero characters (actually cs[0]) as left-most digits ("%.8d")
|
||||
var right int // space characters to right of digits for left justification ("%-8d")
|
||||
|
||||
// determine number padding from precision: the least number of digits to output
|
||||
precision, precisionSet := s.Precision()
|
||||
if precisionSet {
|
||||
switch {
|
||||
case len(digits) < precision:
|
||||
zeroes = precision - len(digits) // count of zero padding
|
||||
case digits == "0" && precision == 0:
|
||||
return // print nothing if zero value (x == 0) and zero precision ("." or ".0")
|
||||
}
|
||||
}
|
||||
|
||||
// determine field pad from width: the least number of characters to output
|
||||
length := len(sign) + len(prefix) + zeroes + len(digits)
|
||||
if width, widthSet := s.Width(); widthSet && length < width { // pad as specified
|
||||
switch d := width - length; {
|
||||
case s.Flag('-'):
|
||||
// pad on the right with spaces; supersedes '0' when both specified
|
||||
right = d
|
||||
case s.Flag('0') && !precisionSet:
|
||||
// pad with zeroes unless precision also specified
|
||||
zeroes = d
|
||||
default:
|
||||
// pad on the left with spaces
|
||||
left = d
|
||||
}
|
||||
}
|
||||
|
||||
// print number as [left pad][sign][prefix][zero pad][digits][right pad]
|
||||
writeMultiple(s, " ", left)
|
||||
writeMultiple(s, sign, 1)
|
||||
writeMultiple(s, prefix, 1)
|
||||
writeMultiple(s, "0", zeroes)
|
||||
writeMultiple(s, digits, 1)
|
||||
writeMultiple(s, " ", right)
|
||||
}
|
||||
|
||||
// scan sets z to the integer value corresponding to the longest possible prefix
|
||||
// read from r representing a signed integer number in a given conversion base.
|
||||
// It returns z, the actual conversion base used, and an error, if any. In the
|
||||
// error case, the value of z is undefined. The syntax follows the syntax of
|
||||
// integer literals in Go.
|
||||
//
|
||||
// The base argument must be 0 or a value from 2 through MaxBase. If the base
|
||||
// is 0, the string prefix determines the actual conversion base. A prefix of
|
||||
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
|
||||
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
|
||||
//
|
||||
func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) {
|
||||
// determine sign
|
||||
ch, _, err := r.ReadRune()
|
||||
if err != nil {
|
||||
return z, 0, err
|
||||
}
|
||||
neg := false
|
||||
switch ch {
|
||||
case '-':
|
||||
neg = true
|
||||
case '+': // nothing to do
|
||||
default:
|
||||
r.UnreadRune()
|
||||
}
|
||||
|
||||
// Int64 returns the int64 representation of z.
|
||||
// If z cannot be represented in an int64, the result is undefined.
|
||||
// determine mantissa
|
||||
z.abs, base, err = z.abs.scan(r, base)
|
||||
if err != nil {
|
||||
return z, base, err
|
||||
}
|
||||
z.neg = len(z.abs) > 0 && neg // 0 has no sign
|
||||
|
||||
return z, base, nil
|
||||
}
|
||||
|
||||
// Scan is a support routine for fmt.Scanner; it sets z to the value of
|
||||
// the scanned number. It accepts the formats 'b' (binary), 'o' (octal),
|
||||
// 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
|
||||
func (z *Int) Scan(s fmt.ScanState, ch int) os.Error {
|
||||
s.SkipSpace() // skip leading space characters
|
||||
base := 0
|
||||
switch ch {
|
||||
case 'b':
|
||||
base = 2
|
||||
case 'o':
|
||||
base = 8
|
||||
case 'd':
|
||||
base = 10
|
||||
case 'x', 'X':
|
||||
base = 16
|
||||
case 's', 'v':
|
||||
// let scan determine the base
|
||||
default:
|
||||
return os.NewError("Int.Scan: invalid verb")
|
||||
}
|
||||
_, _, err := z.scan(s, base)
|
||||
return err
|
||||
}
|
||||
|
||||
// Int64 returns the int64 representation of x.
|
||||
// If x cannot be represented in an int64, the result is undefined.
|
||||
func (x *Int) Int64() int64 {
|
||||
if len(x.abs) == 0 {
|
||||
return 0
|
||||
|
@ -364,40 +495,25 @@ func (x *Int) Int64() int64 {
|
|||
return v
|
||||
}
|
||||
|
||||
|
||||
// SetString sets z to the value of s, interpreted in the given base,
|
||||
// and returns z and a boolean indicating success. If SetString fails,
|
||||
// the value of z is undefined.
|
||||
//
|
||||
// If the base argument is 0, the string prefix determines the actual
|
||||
// conversion base. A prefix of ``0x'' or ``0X'' selects base 16; the
|
||||
// ``0'' prefix selects base 8, and a ``0b'' or ``0B'' prefix selects
|
||||
// base 2. Otherwise the selected base is 10.
|
||||
// The base argument must be 0 or a value from 2 through MaxBase. If the base
|
||||
// is 0, the string prefix determines the actual conversion base. A prefix of
|
||||
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
|
||||
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
|
||||
//
|
||||
func (z *Int) SetString(s string, base int) (*Int, bool) {
|
||||
if len(s) == 0 || base < 0 || base == 1 || 16 < base {
|
||||
r := strings.NewReader(s)
|
||||
_, _, err := z.scan(r, base)
|
||||
if err != nil {
|
||||
return z, false
|
||||
}
|
||||
|
||||
neg := s[0] == '-'
|
||||
if neg || s[0] == '+' {
|
||||
s = s[1:]
|
||||
if len(s) == 0 {
|
||||
return z, false
|
||||
}
|
||||
}
|
||||
|
||||
var scanned int
|
||||
z.abs, _, scanned = z.abs.scan(s, base)
|
||||
if scanned != len(s) {
|
||||
return z, false
|
||||
}
|
||||
z.neg = len(z.abs) > 0 && neg // 0 has no sign
|
||||
|
||||
return z, true
|
||||
_, _, err = r.ReadRune()
|
||||
return z, err == os.EOF // err == os.EOF => scan consumed all of s
|
||||
}
|
||||
|
||||
|
||||
// SetBytes interprets buf as the bytes of a big-endian unsigned
|
||||
// integer, sets z to that value, and returns z.
|
||||
func (z *Int) SetBytes(buf []byte) *Int {
|
||||
|
@ -406,21 +522,18 @@ func (z *Int) SetBytes(buf []byte) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Bytes returns the absolute value of z as a big-endian byte slice.
|
||||
func (z *Int) Bytes() []byte {
|
||||
buf := make([]byte, len(z.abs)*_S)
|
||||
return buf[z.abs.bytes(buf):]
|
||||
}
|
||||
|
||||
|
||||
// BitLen returns the length of the absolute value of z in bits.
|
||||
// The bit length of 0 is 0.
|
||||
func (z *Int) BitLen() int {
|
||||
return z.abs.bitLen()
|
||||
}
|
||||
|
||||
|
||||
// Exp sets z = x**y mod m. If m is nil, z = x**y.
|
||||
// See Knuth, volume 2, section 4.6.3.
|
||||
func (z *Int) Exp(x, y, m *Int) *Int {
|
||||
|
@ -441,7 +554,6 @@ func (z *Int) Exp(x, y, m *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// GcdInt sets d to the greatest common divisor of a and b, which must be
|
||||
// positive numbers.
|
||||
// If x and y are not nil, GcdInt sets x and y such that d = a*x + b*y.
|
||||
|
@ -500,7 +612,6 @@ func GcdInt(d, x, y, a, b *Int) {
|
|||
*d = *A
|
||||
}
|
||||
|
||||
|
||||
// ProbablyPrime performs n Miller-Rabin tests to check whether z is prime.
|
||||
// If it returns true, z is prime with probability 1 - 1/4^n.
|
||||
// If it returns false, z is not prime.
|
||||
|
@ -508,8 +619,7 @@ func ProbablyPrime(z *Int, n int) bool {
|
|||
return !z.neg && z.abs.probablyPrime(n)
|
||||
}
|
||||
|
||||
|
||||
// Rand sets z to a pseudo-random number in [0, n) and returns z.
|
||||
// Rand sets z to a pseudo-random number in [0, n) and returns z.
|
||||
func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
|
||||
z.neg = false
|
||||
if n.neg == true || len(n.abs) == 0 {
|
||||
|
@ -520,7 +630,6 @@ func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// ModInverse sets z to the multiplicative inverse of g in the group ℤ/pℤ (where
|
||||
// p is a prime) and returns z.
|
||||
func (z *Int) ModInverse(g, p *Int) *Int {
|
||||
|
@ -534,7 +643,6 @@ func (z *Int) ModInverse(g, p *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Lsh sets z = x << n and returns z.
|
||||
func (z *Int) Lsh(x *Int, n uint) *Int {
|
||||
z.abs = z.abs.shl(x.abs, n)
|
||||
|
@ -542,7 +650,6 @@ func (z *Int) Lsh(x *Int, n uint) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Rsh sets z = x >> n and returns z.
|
||||
func (z *Int) Rsh(x *Int, n uint) *Int {
|
||||
if x.neg {
|
||||
|
@ -559,6 +666,39 @@ func (z *Int) Rsh(x *Int, n uint) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
// Bit returns the value of the i'th bit of z. That is, it
|
||||
// returns (z>>i)&1. The bit index i must be >= 0.
|
||||
func (z *Int) Bit(i int) uint {
|
||||
if i < 0 {
|
||||
panic("negative bit index")
|
||||
}
|
||||
if z.neg {
|
||||
t := nat{}.sub(z.abs, natOne)
|
||||
return t.bit(uint(i)) ^ 1
|
||||
}
|
||||
|
||||
return z.abs.bit(uint(i))
|
||||
}
|
||||
|
||||
// SetBit sets the i'th bit of z to bit and returns z.
|
||||
// That is, if bit is 1 SetBit sets z = x | (1 << i);
|
||||
// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
|
||||
// SetBit will panic.
|
||||
func (z *Int) SetBit(x *Int, i int, b uint) *Int {
|
||||
if i < 0 {
|
||||
panic("negative bit index")
|
||||
}
|
||||
if x.neg {
|
||||
t := z.abs.sub(x.abs, natOne)
|
||||
t = t.setBit(t, uint(i), b^1)
|
||||
z.abs = t.add(t, natOne)
|
||||
z.neg = len(z.abs) > 0
|
||||
return z
|
||||
}
|
||||
z.abs = z.abs.setBit(x.abs, uint(i), b)
|
||||
z.neg = false
|
||||
return z
|
||||
}
|
||||
|
||||
// And sets z = x & y and returns z.
|
||||
func (z *Int) And(x, y *Int) *Int {
|
||||
|
@ -590,7 +730,6 @@ func (z *Int) And(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// AndNot sets z = x &^ y and returns z.
|
||||
func (z *Int) AndNot(x, y *Int) *Int {
|
||||
if x.neg == y.neg {
|
||||
|
@ -624,7 +763,6 @@ func (z *Int) AndNot(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Or sets z = x | y and returns z.
|
||||
func (z *Int) Or(x, y *Int) *Int {
|
||||
if x.neg == y.neg {
|
||||
|
@ -655,7 +793,6 @@ func (z *Int) Or(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Xor sets z = x ^ y and returns z.
|
||||
func (z *Int) Xor(x, y *Int) *Int {
|
||||
if x.neg == y.neg {
|
||||
|
@ -686,7 +823,6 @@ func (z *Int) Xor(x, y *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Not sets z = ^x and returns z.
|
||||
func (z *Int) Not(x *Int) *Int {
|
||||
if x.neg {
|
||||
|
@ -702,15 +838,14 @@ func (z *Int) Not(x *Int) *Int {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Gob codec version. Permits backward-compatible changes to the encoding.
|
||||
const version byte = 1
|
||||
const intGobVersion byte = 1
|
||||
|
||||
// GobEncode implements the gob.GobEncoder interface.
|
||||
func (z *Int) GobEncode() ([]byte, os.Error) {
|
||||
buf := make([]byte, len(z.abs)*_S+1) // extra byte for version and sign bit
|
||||
buf := make([]byte, 1+len(z.abs)*_S) // extra byte for version and sign bit
|
||||
i := z.abs.bytes(buf) - 1 // i >= 0
|
||||
b := version << 1 // make space for sign bit
|
||||
b := intGobVersion << 1 // make space for sign bit
|
||||
if z.neg {
|
||||
b |= 1
|
||||
}
|
||||
|
@ -718,14 +853,13 @@ func (z *Int) GobEncode() ([]byte, os.Error) {
|
|||
return buf[i:], nil
|
||||
}
|
||||
|
||||
|
||||
// GobDecode implements the gob.GobDecoder interface.
|
||||
func (z *Int) GobDecode(buf []byte) os.Error {
|
||||
if len(buf) == 0 {
|
||||
return os.NewError("Int.GobDecode: no data")
|
||||
}
|
||||
b := buf[0]
|
||||
if b>>1 != version {
|
||||
if b>>1 != intGobVersion {
|
||||
return os.NewError(fmt.Sprintf("Int.GobDecode: encoding version %d not supported", b>>1))
|
||||
}
|
||||
z.neg = b&1 != 0
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
"testing/quick"
|
||||
)
|
||||
|
||||
|
||||
func isNormalized(x *Int) bool {
|
||||
if len(x.abs) == 0 {
|
||||
return !x.neg
|
||||
|
@ -22,13 +21,11 @@ func isNormalized(x *Int) bool {
|
|||
return x.abs[len(x.abs)-1] != 0
|
||||
}
|
||||
|
||||
|
||||
type funZZ func(z, x, y *Int) *Int
|
||||
type argZZ struct {
|
||||
z, x, y *Int
|
||||
}
|
||||
|
||||
|
||||
var sumZZ = []argZZ{
|
||||
{NewInt(0), NewInt(0), NewInt(0)},
|
||||
{NewInt(1), NewInt(1), NewInt(0)},
|
||||
|
@ -38,7 +35,6 @@ var sumZZ = []argZZ{
|
|||
{NewInt(-1111111110), NewInt(-123456789), NewInt(-987654321)},
|
||||
}
|
||||
|
||||
|
||||
var prodZZ = []argZZ{
|
||||
{NewInt(0), NewInt(0), NewInt(0)},
|
||||
{NewInt(0), NewInt(1), NewInt(0)},
|
||||
|
@ -47,7 +43,6 @@ var prodZZ = []argZZ{
|
|||
// TODO(gri) add larger products
|
||||
}
|
||||
|
||||
|
||||
func TestSignZ(t *testing.T) {
|
||||
var zero Int
|
||||
for _, a := range sumZZ {
|
||||
|
@ -59,7 +54,6 @@ func TestSignZ(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestSetZ(t *testing.T) {
|
||||
for _, a := range sumZZ {
|
||||
var z Int
|
||||
|
@ -73,7 +67,6 @@ func TestSetZ(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestAbsZ(t *testing.T) {
|
||||
var zero Int
|
||||
for _, a := range sumZZ {
|
||||
|
@ -90,7 +83,6 @@ func TestAbsZ(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) {
|
||||
var z Int
|
||||
f(&z, a.x, a.y)
|
||||
|
@ -102,7 +94,6 @@ func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestSumZZ(t *testing.T) {
|
||||
AddZZ := func(z, x, y *Int) *Int { return z.Add(x, y) }
|
||||
SubZZ := func(z, x, y *Int) *Int { return z.Sub(x, y) }
|
||||
|
@ -121,7 +112,6 @@ func TestSumZZ(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestProdZZ(t *testing.T) {
|
||||
MulZZ := func(z, x, y *Int) *Int { return z.Mul(x, y) }
|
||||
for _, a := range prodZZ {
|
||||
|
@ -133,7 +123,6 @@ func TestProdZZ(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// mulBytes returns x*y via grade school multiplication. Both inputs
|
||||
// and the result are assumed to be in big-endian representation (to
|
||||
// match the semantics of Int.Bytes and Int.SetBytes).
|
||||
|
@ -166,7 +155,6 @@ func mulBytes(x, y []byte) []byte {
|
|||
return z[i:]
|
||||
}
|
||||
|
||||
|
||||
func checkMul(a, b []byte) bool {
|
||||
var x, y, z1 Int
|
||||
x.SetBytes(a)
|
||||
|
@ -179,14 +167,12 @@ func checkMul(a, b []byte) bool {
|
|||
return z1.Cmp(&z2) == 0
|
||||
}
|
||||
|
||||
|
||||
func TestMul(t *testing.T) {
|
||||
if err := quick.Check(checkMul, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
var mulRangesZ = []struct {
|
||||
a, b int64
|
||||
prod string
|
||||
|
@ -212,7 +198,6 @@ var mulRangesZ = []struct {
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
func TestMulRangeZ(t *testing.T) {
|
||||
var tmp Int
|
||||
// test entirely positive ranges
|
||||
|
@ -231,7 +216,6 @@ func TestMulRangeZ(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var stringTests = []struct {
|
||||
in string
|
||||
out string
|
||||
|
@ -280,7 +264,6 @@ var stringTests = []struct {
|
|||
{"1001010111", "1001010111", 2, 0x257, true},
|
||||
}
|
||||
|
||||
|
||||
func format(base int) string {
|
||||
switch base {
|
||||
case 2:
|
||||
|
@ -293,7 +276,6 @@ func format(base int) string {
|
|||
return "%d"
|
||||
}
|
||||
|
||||
|
||||
func TestGetString(t *testing.T) {
|
||||
z := new(Int)
|
||||
for i, test := range stringTests {
|
||||
|
@ -316,7 +298,6 @@ func TestGetString(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestSetString(t *testing.T) {
|
||||
tmp := new(Int)
|
||||
for i, test := range stringTests {
|
||||
|
@ -347,6 +328,212 @@ func TestSetString(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var formatTests = []struct {
|
||||
input string
|
||||
format string
|
||||
output string
|
||||
}{
|
||||
{"<nil>", "%x", "<nil>"},
|
||||
{"<nil>", "%#x", "<nil>"},
|
||||
{"<nil>", "%#y", "%!y(big.Int=<nil>)"},
|
||||
|
||||
{"10", "%b", "1010"},
|
||||
{"10", "%o", "12"},
|
||||
{"10", "%d", "10"},
|
||||
{"10", "%v", "10"},
|
||||
{"10", "%x", "a"},
|
||||
{"10", "%X", "A"},
|
||||
{"-10", "%X", "-A"},
|
||||
{"10", "%y", "%!y(big.Int=10)"},
|
||||
{"-10", "%y", "%!y(big.Int=-10)"},
|
||||
|
||||
{"10", "%#b", "1010"},
|
||||
{"10", "%#o", "012"},
|
||||
{"10", "%#d", "10"},
|
||||
{"10", "%#v", "10"},
|
||||
{"10", "%#x", "0xa"},
|
||||
{"10", "%#X", "0XA"},
|
||||
{"-10", "%#X", "-0XA"},
|
||||
{"10", "%#y", "%!y(big.Int=10)"},
|
||||
{"-10", "%#y", "%!y(big.Int=-10)"},
|
||||
|
||||
{"1234", "%d", "1234"},
|
||||
{"1234", "%3d", "1234"},
|
||||
{"1234", "%4d", "1234"},
|
||||
{"-1234", "%d", "-1234"},
|
||||
{"1234", "% 5d", " 1234"},
|
||||
{"1234", "%+5d", "+1234"},
|
||||
{"1234", "%-5d", "1234 "},
|
||||
{"1234", "%x", "4d2"},
|
||||
{"1234", "%X", "4D2"},
|
||||
{"-1234", "%3x", "-4d2"},
|
||||
{"-1234", "%4x", "-4d2"},
|
||||
{"-1234", "%5x", " -4d2"},
|
||||
{"-1234", "%-5x", "-4d2 "},
|
||||
{"1234", "%03d", "1234"},
|
||||
{"1234", "%04d", "1234"},
|
||||
{"1234", "%05d", "01234"},
|
||||
{"1234", "%06d", "001234"},
|
||||
{"-1234", "%06d", "-01234"},
|
||||
{"1234", "%+06d", "+01234"},
|
||||
{"1234", "% 06d", " 01234"},
|
||||
{"1234", "%-6d", "1234 "},
|
||||
{"1234", "%-06d", "1234 "},
|
||||
{"-1234", "%-06d", "-1234 "},
|
||||
|
||||
{"1234", "%.3d", "1234"},
|
||||
{"1234", "%.4d", "1234"},
|
||||
{"1234", "%.5d", "01234"},
|
||||
{"1234", "%.6d", "001234"},
|
||||
{"-1234", "%.3d", "-1234"},
|
||||
{"-1234", "%.4d", "-1234"},
|
||||
{"-1234", "%.5d", "-01234"},
|
||||
{"-1234", "%.6d", "-001234"},
|
||||
|
||||
{"1234", "%8.3d", " 1234"},
|
||||
{"1234", "%8.4d", " 1234"},
|
||||
{"1234", "%8.5d", " 01234"},
|
||||
{"1234", "%8.6d", " 001234"},
|
||||
{"-1234", "%8.3d", " -1234"},
|
||||
{"-1234", "%8.4d", " -1234"},
|
||||
{"-1234", "%8.5d", " -01234"},
|
||||
{"-1234", "%8.6d", " -001234"},
|
||||
|
||||
{"1234", "%+8.3d", " +1234"},
|
||||
{"1234", "%+8.4d", " +1234"},
|
||||
{"1234", "%+8.5d", " +01234"},
|
||||
{"1234", "%+8.6d", " +001234"},
|
||||
{"-1234", "%+8.3d", " -1234"},
|
||||
{"-1234", "%+8.4d", " -1234"},
|
||||
{"-1234", "%+8.5d", " -01234"},
|
||||
{"-1234", "%+8.6d", " -001234"},
|
||||
|
||||
{"1234", "% 8.3d", " 1234"},
|
||||
{"1234", "% 8.4d", " 1234"},
|
||||
{"1234", "% 8.5d", " 01234"},
|
||||
{"1234", "% 8.6d", " 001234"},
|
||||
{"-1234", "% 8.3d", " -1234"},
|
||||
{"-1234", "% 8.4d", " -1234"},
|
||||
{"-1234", "% 8.5d", " -01234"},
|
||||
{"-1234", "% 8.6d", " -001234"},
|
||||
|
||||
{"1234", "%.3x", "4d2"},
|
||||
{"1234", "%.4x", "04d2"},
|
||||
{"1234", "%.5x", "004d2"},
|
||||
{"1234", "%.6x", "0004d2"},
|
||||
{"-1234", "%.3x", "-4d2"},
|
||||
{"-1234", "%.4x", "-04d2"},
|
||||
{"-1234", "%.5x", "-004d2"},
|
||||
{"-1234", "%.6x", "-0004d2"},
|
||||
|
||||
{"1234", "%8.3x", " 4d2"},
|
||||
{"1234", "%8.4x", " 04d2"},
|
||||
{"1234", "%8.5x", " 004d2"},
|
||||
{"1234", "%8.6x", " 0004d2"},
|
||||
{"-1234", "%8.3x", " -4d2"},
|
||||
{"-1234", "%8.4x", " -04d2"},
|
||||
{"-1234", "%8.5x", " -004d2"},
|
||||
{"-1234", "%8.6x", " -0004d2"},
|
||||
|
||||
{"1234", "%+8.3x", " +4d2"},
|
||||
{"1234", "%+8.4x", " +04d2"},
|
||||
{"1234", "%+8.5x", " +004d2"},
|
||||
{"1234", "%+8.6x", " +0004d2"},
|
||||
{"-1234", "%+8.3x", " -4d2"},
|
||||
{"-1234", "%+8.4x", " -04d2"},
|
||||
{"-1234", "%+8.5x", " -004d2"},
|
||||
{"-1234", "%+8.6x", " -0004d2"},
|
||||
|
||||
{"1234", "% 8.3x", " 4d2"},
|
||||
{"1234", "% 8.4x", " 04d2"},
|
||||
{"1234", "% 8.5x", " 004d2"},
|
||||
{"1234", "% 8.6x", " 0004d2"},
|
||||
{"1234", "% 8.7x", " 00004d2"},
|
||||
{"1234", "% 8.8x", " 000004d2"},
|
||||
{"-1234", "% 8.3x", " -4d2"},
|
||||
{"-1234", "% 8.4x", " -04d2"},
|
||||
{"-1234", "% 8.5x", " -004d2"},
|
||||
{"-1234", "% 8.6x", " -0004d2"},
|
||||
{"-1234", "% 8.7x", "-00004d2"},
|
||||
{"-1234", "% 8.8x", "-000004d2"},
|
||||
|
||||
{"1234", "%-8.3d", "1234 "},
|
||||
{"1234", "%-8.4d", "1234 "},
|
||||
{"1234", "%-8.5d", "01234 "},
|
||||
{"1234", "%-8.6d", "001234 "},
|
||||
{"1234", "%-8.7d", "0001234 "},
|
||||
{"1234", "%-8.8d", "00001234"},
|
||||
{"-1234", "%-8.3d", "-1234 "},
|
||||
{"-1234", "%-8.4d", "-1234 "},
|
||||
{"-1234", "%-8.5d", "-01234 "},
|
||||
{"-1234", "%-8.6d", "-001234 "},
|
||||
{"-1234", "%-8.7d", "-0001234"},
|
||||
{"-1234", "%-8.8d", "-00001234"},
|
||||
|
||||
{"16777215", "%b", "111111111111111111111111"}, // 2**24 - 1
|
||||
|
||||
{"0", "%.d", ""},
|
||||
{"0", "%.0d", ""},
|
||||
{"0", "%3.d", ""},
|
||||
}
|
||||
|
||||
func TestFormat(t *testing.T) {
|
||||
for i, test := range formatTests {
|
||||
var x *Int
|
||||
if test.input != "<nil>" {
|
||||
var ok bool
|
||||
x, ok = new(Int).SetString(test.input, 0)
|
||||
if !ok {
|
||||
t.Errorf("#%d failed reading input %s", i, test.input)
|
||||
}
|
||||
}
|
||||
output := fmt.Sprintf(test.format, x)
|
||||
if output != test.output {
|
||||
t.Errorf("#%d got %q; want %q, {%q, %q, %q}", i, output, test.output, test.input, test.format, test.output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var scanTests = []struct {
|
||||
input string
|
||||
format string
|
||||
output string
|
||||
remaining int
|
||||
}{
|
||||
{"1010", "%b", "10", 0},
|
||||
{"0b1010", "%v", "10", 0},
|
||||
{"12", "%o", "10", 0},
|
||||
{"012", "%v", "10", 0},
|
||||
{"10", "%d", "10", 0},
|
||||
{"10", "%v", "10", 0},
|
||||
{"a", "%x", "10", 0},
|
||||
{"0xa", "%v", "10", 0},
|
||||
{"A", "%X", "10", 0},
|
||||
{"-A", "%X", "-10", 0},
|
||||
{"+0b1011001", "%v", "89", 0},
|
||||
{"0xA", "%v", "10", 0},
|
||||
{"0 ", "%v", "0", 1},
|
||||
{"2+3", "%v", "2", 2},
|
||||
{"0XABC 12", "%v", "2748", 3},
|
||||
}
|
||||
|
||||
func TestScan(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
for i, test := range scanTests {
|
||||
x := new(Int)
|
||||
buf.Reset()
|
||||
buf.WriteString(test.input)
|
||||
if _, err := fmt.Fscanf(&buf, test.format, x); err != nil {
|
||||
t.Errorf("#%d error: %s", i, err.String())
|
||||
}
|
||||
if x.String() != test.output {
|
||||
t.Errorf("#%d got %s; want %s", i, x.String(), test.output)
|
||||
}
|
||||
if buf.Len() != test.remaining {
|
||||
t.Errorf("#%d got %d bytes remaining; want %d", i, buf.Len(), test.remaining)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Examples from the Go Language Spec, section "Arithmetic operators"
|
||||
var divisionSignsTests = []struct {
|
||||
|
@ -362,7 +549,6 @@ var divisionSignsTests = []struct {
|
|||
{8, 4, 2, 0, 2, 0},
|
||||
}
|
||||
|
||||
|
||||
func TestDivisionSigns(t *testing.T) {
|
||||
for i, test := range divisionSignsTests {
|
||||
x := NewInt(test.x)
|
||||
|
@ -420,7 +606,6 @@ func TestDivisionSigns(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func checkSetBytes(b []byte) bool {
|
||||
hex1 := hex.EncodeToString(new(Int).SetBytes(b).Bytes())
|
||||
hex2 := hex.EncodeToString(b)
|
||||
|
@ -436,27 +621,23 @@ func checkSetBytes(b []byte) bool {
|
|||
return hex1 == hex2
|
||||
}
|
||||
|
||||
|
||||
func TestSetBytes(t *testing.T) {
|
||||
if err := quick.Check(checkSetBytes, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func checkBytes(b []byte) bool {
|
||||
b2 := new(Int).SetBytes(b).Bytes()
|
||||
return bytes.Compare(b, b2) == 0
|
||||
}
|
||||
|
||||
|
||||
func TestBytes(t *testing.T) {
|
||||
if err := quick.Check(checkSetBytes, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func checkQuo(x, y []byte) bool {
|
||||
u := new(Int).SetBytes(x)
|
||||
v := new(Int).SetBytes(y)
|
||||
|
@ -479,7 +660,6 @@ func checkQuo(x, y []byte) bool {
|
|||
return uprime.Cmp(u) == 0
|
||||
}
|
||||
|
||||
|
||||
var quoTests = []struct {
|
||||
x, y string
|
||||
q, r string
|
||||
|
@ -498,7 +678,6 @@ var quoTests = []struct {
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
func TestQuo(t *testing.T) {
|
||||
if err := quick.Check(checkQuo, nil); err != nil {
|
||||
t.Error(err)
|
||||
|
@ -519,7 +698,6 @@ func TestQuo(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestQuoStepD6(t *testing.T) {
|
||||
// See Knuth, Volume 2, section 4.3.1, exercise 21. This code exercises
|
||||
// a code path which only triggers 1 in 10^{-19} cases.
|
||||
|
@ -539,7 +717,6 @@ func TestQuoStepD6(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var bitLenTests = []struct {
|
||||
in string
|
||||
out int
|
||||
|
@ -558,7 +735,6 @@ var bitLenTests = []struct {
|
|||
{"-0x4000000000000000000000", 87},
|
||||
}
|
||||
|
||||
|
||||
func TestBitLen(t *testing.T) {
|
||||
for i, test := range bitLenTests {
|
||||
x, ok := new(Int).SetString(test.in, 0)
|
||||
|
@ -573,7 +749,6 @@ func TestBitLen(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var expTests = []struct {
|
||||
x, y, m string
|
||||
out string
|
||||
|
@ -598,7 +773,6 @@ var expTests = []struct {
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
func TestExp(t *testing.T) {
|
||||
for i, test := range expTests {
|
||||
x, ok1 := new(Int).SetString(test.x, 0)
|
||||
|
@ -629,7 +803,6 @@ func TestExp(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func checkGcd(aBytes, bBytes []byte) bool {
|
||||
a := new(Int).SetBytes(aBytes)
|
||||
b := new(Int).SetBytes(bBytes)
|
||||
|
@ -646,7 +819,6 @@ func checkGcd(aBytes, bBytes []byte) bool {
|
|||
return x.Cmp(d) == 0
|
||||
}
|
||||
|
||||
|
||||
var gcdTests = []struct {
|
||||
a, b int64
|
||||
d, x, y int64
|
||||
|
@ -654,7 +826,6 @@ var gcdTests = []struct {
|
|||
{120, 23, 1, -9, 47},
|
||||
}
|
||||
|
||||
|
||||
func TestGcd(t *testing.T) {
|
||||
for i, test := range gcdTests {
|
||||
a := NewInt(test.a)
|
||||
|
@ -680,7 +851,6 @@ func TestGcd(t *testing.T) {
|
|||
quick.Check(checkGcd, nil)
|
||||
}
|
||||
|
||||
|
||||
var primes = []string{
|
||||
"2",
|
||||
"3",
|
||||
|
@ -706,7 +876,6 @@ var primes = []string{
|
|||
"203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123",
|
||||
}
|
||||
|
||||
|
||||
var composites = []string{
|
||||
"21284175091214687912771199898307297748211672914763848041968395774954376176754",
|
||||
"6084766654921918907427900243509372380954290099172559290432744450051395395951",
|
||||
|
@ -714,7 +883,6 @@ var composites = []string{
|
|||
"82793403787388584738507275144194252681",
|
||||
}
|
||||
|
||||
|
||||
func TestProbablyPrime(t *testing.T) {
|
||||
nreps := 20
|
||||
if testing.Short() {
|
||||
|
@ -738,14 +906,12 @@ func TestProbablyPrime(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type intShiftTest struct {
|
||||
in string
|
||||
shift uint
|
||||
out string
|
||||
}
|
||||
|
||||
|
||||
var rshTests = []intShiftTest{
|
||||
{"0", 0, "0"},
|
||||
{"-0", 0, "0"},
|
||||
|
@ -773,7 +939,6 @@ var rshTests = []intShiftTest{
|
|||
{"340282366920938463463374607431768211456", 128, "1"},
|
||||
}
|
||||
|
||||
|
||||
func TestRsh(t *testing.T) {
|
||||
for i, test := range rshTests {
|
||||
in, _ := new(Int).SetString(test.in, 10)
|
||||
|
@ -789,7 +954,6 @@ func TestRsh(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestRshSelf(t *testing.T) {
|
||||
for i, test := range rshTests {
|
||||
z, _ := new(Int).SetString(test.in, 10)
|
||||
|
@ -805,7 +969,6 @@ func TestRshSelf(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var lshTests = []intShiftTest{
|
||||
{"0", 0, "0"},
|
||||
{"0", 1, "0"},
|
||||
|
@ -828,7 +991,6 @@ var lshTests = []intShiftTest{
|
|||
{"1", 128, "340282366920938463463374607431768211456"},
|
||||
}
|
||||
|
||||
|
||||
func TestLsh(t *testing.T) {
|
||||
for i, test := range lshTests {
|
||||
in, _ := new(Int).SetString(test.in, 10)
|
||||
|
@ -844,7 +1006,6 @@ func TestLsh(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestLshSelf(t *testing.T) {
|
||||
for i, test := range lshTests {
|
||||
z, _ := new(Int).SetString(test.in, 10)
|
||||
|
@ -860,7 +1021,6 @@ func TestLshSelf(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestLshRsh(t *testing.T) {
|
||||
for i, test := range rshTests {
|
||||
in, _ := new(Int).SetString(test.in, 10)
|
||||
|
@ -888,7 +1048,6 @@ func TestLshRsh(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var int64Tests = []int64{
|
||||
0,
|
||||
1,
|
||||
|
@ -902,7 +1061,6 @@ var int64Tests = []int64{
|
|||
-9223372036854775808,
|
||||
}
|
||||
|
||||
|
||||
func TestInt64(t *testing.T) {
|
||||
for i, testVal := range int64Tests {
|
||||
in := NewInt(testVal)
|
||||
|
@ -914,7 +1072,6 @@ func TestInt64(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var bitwiseTests = []struct {
|
||||
x, y string
|
||||
and, or, xor, andNot string
|
||||
|
@ -958,7 +1115,6 @@ var bitwiseTests = []struct {
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
type bitFun func(z, x, y *Int) *Int
|
||||
|
||||
func testBitFun(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
|
||||
|
@ -971,7 +1127,6 @@ func testBitFun(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
|
||||
self := new(Int)
|
||||
self.Set(x)
|
||||
|
@ -984,6 +1139,142 @@ func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
|
|||
}
|
||||
}
|
||||
|
||||
func altBit(x *Int, i int) uint {
|
||||
z := new(Int).Rsh(x, uint(i))
|
||||
z = z.And(z, NewInt(1))
|
||||
if z.Cmp(new(Int)) != 0 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func altSetBit(z *Int, x *Int, i int, b uint) *Int {
|
||||
one := NewInt(1)
|
||||
m := one.Lsh(one, uint(i))
|
||||
switch b {
|
||||
case 1:
|
||||
return z.Or(x, m)
|
||||
case 0:
|
||||
return z.AndNot(x, m)
|
||||
}
|
||||
panic("set bit is not 0 or 1")
|
||||
}
|
||||
|
||||
func testBitset(t *testing.T, x *Int) {
|
||||
n := x.BitLen()
|
||||
z := new(Int).Set(x)
|
||||
z1 := new(Int).Set(x)
|
||||
for i := 0; i < n+10; i++ {
|
||||
old := z.Bit(i)
|
||||
old1 := altBit(z1, i)
|
||||
if old != old1 {
|
||||
t.Errorf("bitset: inconsistent value for Bit(%s, %d), got %v want %v", z1, i, old, old1)
|
||||
}
|
||||
z := new(Int).SetBit(z, i, 1)
|
||||
z1 := altSetBit(new(Int), z1, i, 1)
|
||||
if z.Bit(i) == 0 {
|
||||
t.Errorf("bitset: bit %d of %s got 0 want 1", i, x)
|
||||
}
|
||||
if z.Cmp(z1) != 0 {
|
||||
t.Errorf("bitset: inconsistent value after SetBit 1, got %s want %s", z, z1)
|
||||
}
|
||||
z.SetBit(z, i, 0)
|
||||
altSetBit(z1, z1, i, 0)
|
||||
if z.Bit(i) != 0 {
|
||||
t.Errorf("bitset: bit %d of %s got 1 want 0", i, x)
|
||||
}
|
||||
if z.Cmp(z1) != 0 {
|
||||
t.Errorf("bitset: inconsistent value after SetBit 0, got %s want %s", z, z1)
|
||||
}
|
||||
altSetBit(z1, z1, i, old)
|
||||
z.SetBit(z, i, old)
|
||||
if z.Cmp(z1) != 0 {
|
||||
t.Errorf("bitset: inconsistent value after SetBit old, got %s want %s", z, z1)
|
||||
}
|
||||
}
|
||||
if z.Cmp(x) != 0 {
|
||||
t.Errorf("bitset: got %s want %s", z, x)
|
||||
}
|
||||
}
|
||||
|
||||
var bitsetTests = []struct {
|
||||
x string
|
||||
i int
|
||||
b uint
|
||||
}{
|
||||
{"0", 0, 0},
|
||||
{"0", 200, 0},
|
||||
{"1", 0, 1},
|
||||
{"1", 1, 0},
|
||||
{"-1", 0, 1},
|
||||
{"-1", 200, 1},
|
||||
{"0x2000000000000000000000000000", 108, 0},
|
||||
{"0x2000000000000000000000000000", 109, 1},
|
||||
{"0x2000000000000000000000000000", 110, 0},
|
||||
{"-0x2000000000000000000000000001", 108, 1},
|
||||
{"-0x2000000000000000000000000001", 109, 0},
|
||||
{"-0x2000000000000000000000000001", 110, 1},
|
||||
}
|
||||
|
||||
func TestBitSet(t *testing.T) {
|
||||
for _, test := range bitwiseTests {
|
||||
x := new(Int)
|
||||
x.SetString(test.x, 0)
|
||||
testBitset(t, x)
|
||||
x = new(Int)
|
||||
x.SetString(test.y, 0)
|
||||
testBitset(t, x)
|
||||
}
|
||||
for i, test := range bitsetTests {
|
||||
x := new(Int)
|
||||
x.SetString(test.x, 0)
|
||||
b := x.Bit(test.i)
|
||||
if b != test.b {
|
||||
|
||||
t.Errorf("#%d want %v got %v", i, test.b, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBitset(b *testing.B) {
|
||||
z := new(Int)
|
||||
z.SetBit(z, 512, 1)
|
||||
b.ResetTimer()
|
||||
b.StartTimer()
|
||||
for i := b.N - 1; i >= 0; i-- {
|
||||
z.SetBit(z, i&512, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBitsetNeg(b *testing.B) {
|
||||
z := NewInt(-1)
|
||||
z.SetBit(z, 512, 0)
|
||||
b.ResetTimer()
|
||||
b.StartTimer()
|
||||
for i := b.N - 1; i >= 0; i-- {
|
||||
z.SetBit(z, i&512, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBitsetOrig(b *testing.B) {
|
||||
z := new(Int)
|
||||
altSetBit(z, z, 512, 1)
|
||||
b.ResetTimer()
|
||||
b.StartTimer()
|
||||
for i := b.N - 1; i >= 0; i-- {
|
||||
altSetBit(z, z, i&512, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBitsetNegOrig(b *testing.B) {
|
||||
z := NewInt(-1)
|
||||
altSetBit(z, z, 512, 0)
|
||||
b.ResetTimer()
|
||||
b.StartTimer()
|
||||
for i := b.N - 1; i >= 0; i-- {
|
||||
altSetBit(z, z, i&512, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitwise(t *testing.T) {
|
||||
x := new(Int)
|
||||
|
@ -1003,7 +1294,6 @@ func TestBitwise(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var notTests = []struct {
|
||||
in string
|
||||
out string
|
||||
|
@ -1037,7 +1327,6 @@ func TestNot(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var modInverseTests = []struct {
|
||||
element string
|
||||
prime string
|
||||
|
@ -1062,7 +1351,7 @@ func TestModInverse(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// used by TestIntGobEncoding and TestRatGobEncoding
|
||||
var gobEncodingTests = []string{
|
||||
"0",
|
||||
"1",
|
||||
|
@ -1073,7 +1362,7 @@ var gobEncodingTests = []string{
|
|||
"298472983472983471903246121093472394872319615612417471234712061",
|
||||
}
|
||||
|
||||
func TestGobEncoding(t *testing.T) {
|
||||
func TestIntGobEncoding(t *testing.T) {
|
||||
var medium bytes.Buffer
|
||||
enc := gob.NewEncoder(&medium)
|
||||
dec := gob.NewDecoder(&medium)
|
||||
|
@ -1081,7 +1370,8 @@ func TestGobEncoding(t *testing.T) {
|
|||
for j := 0; j < 2; j++ {
|
||||
medium.Reset() // empty buffer for each test case (in case of failures)
|
||||
stest := test
|
||||
if j == 0 {
|
||||
if j != 0 {
|
||||
// negative numbers
|
||||
stest = "-" + test
|
||||
}
|
||||
var tx Int
|
||||
|
|
|
@ -18,7 +18,11 @@ package big
|
|||
// These are the building blocks for the operations on signed integers
|
||||
// and rationals.
|
||||
|
||||
import "rand"
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"rand"
|
||||
)
|
||||
|
||||
// An unsigned integer x of the form
|
||||
//
|
||||
|
@ -40,14 +44,12 @@ var (
|
|||
natTen = nat{10}
|
||||
)
|
||||
|
||||
|
||||
func (z nat) clear() {
|
||||
for i := range z {
|
||||
z[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func (z nat) norm() nat {
|
||||
i := len(z)
|
||||
for i > 0 && z[i-1] == 0 {
|
||||
|
@ -56,7 +58,6 @@ func (z nat) norm() nat {
|
|||
return z[0:i]
|
||||
}
|
||||
|
||||
|
||||
func (z nat) make(n int) nat {
|
||||
if n <= cap(z) {
|
||||
return z[0:n] // reuse z
|
||||
|
@ -67,7 +68,6 @@ func (z nat) make(n int) nat {
|
|||
return make(nat, n, n+e)
|
||||
}
|
||||
|
||||
|
||||
func (z nat) setWord(x Word) nat {
|
||||
if x == 0 {
|
||||
return z.make(0)
|
||||
|
@ -77,7 +77,6 @@ func (z nat) setWord(x Word) nat {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
func (z nat) setUint64(x uint64) nat {
|
||||
// single-digit values
|
||||
if w := Word(x); uint64(w) == x {
|
||||
|
@ -100,14 +99,12 @@ func (z nat) setUint64(x uint64) nat {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
func (z nat) set(x nat) nat {
|
||||
z = z.make(len(x))
|
||||
copy(z, x)
|
||||
return z
|
||||
}
|
||||
|
||||
|
||||
func (z nat) add(x, y nat) nat {
|
||||
m := len(x)
|
||||
n := len(y)
|
||||
|
@ -134,7 +131,6 @@ func (z nat) add(x, y nat) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
func (z nat) sub(x, y nat) nat {
|
||||
m := len(x)
|
||||
n := len(y)
|
||||
|
@ -163,7 +159,6 @@ func (z nat) sub(x, y nat) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
func (x nat) cmp(y nat) (r int) {
|
||||
m := len(x)
|
||||
n := len(y)
|
||||
|
@ -191,7 +186,6 @@ func (x nat) cmp(y nat) (r int) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func (z nat) mulAddWW(x nat, y, r Word) nat {
|
||||
m := len(x)
|
||||
if m == 0 || y == 0 {
|
||||
|
@ -205,7 +199,6 @@ func (z nat) mulAddWW(x nat, y, r Word) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// basicMul multiplies x and y and leaves the result in z.
|
||||
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
|
||||
func basicMul(z, x, y nat) {
|
||||
|
@ -217,7 +210,6 @@ func basicMul(z, x, y nat) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
|
||||
// Factored out for readability - do not use outside karatsuba.
|
||||
func karatsubaAdd(z, x nat, n int) {
|
||||
|
@ -226,7 +218,6 @@ func karatsubaAdd(z, x nat, n int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Like karatsubaAdd, but does subtract.
|
||||
func karatsubaSub(z, x nat, n int) {
|
||||
if c := subVV(z[0:n], z, x); c != 0 {
|
||||
|
@ -234,7 +225,6 @@ func karatsubaSub(z, x nat, n int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Operands that are shorter than karatsubaThreshold are multiplied using
|
||||
// "grade school" multiplication; for longer operands the Karatsuba algorithm
|
||||
// is used.
|
||||
|
@ -339,13 +329,11 @@ func karatsuba(z, x, y nat) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// alias returns true if x and y share the same base array.
|
||||
func alias(x, y nat) bool {
|
||||
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
|
||||
}
|
||||
|
||||
|
||||
// addAt implements z += x*(1<<(_W*i)); z must be long enough.
|
||||
// (we don't use nat.add because we need z to stay the same
|
||||
// slice, and we don't need to normalize z after each addition)
|
||||
|
@ -360,7 +348,6 @@ func addAt(z, x nat, i int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func max(x, y int) int {
|
||||
if x > y {
|
||||
return x
|
||||
|
@ -368,7 +355,6 @@ func max(x, y int) int {
|
|||
return y
|
||||
}
|
||||
|
||||
|
||||
// karatsubaLen computes an approximation to the maximum k <= n such that
|
||||
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
|
||||
// result is the largest number that can be divided repeatedly by 2 before
|
||||
|
@ -382,7 +368,6 @@ func karatsubaLen(n int) int {
|
|||
return n << i
|
||||
}
|
||||
|
||||
|
||||
func (z nat) mul(x, y nat) nat {
|
||||
m := len(x)
|
||||
n := len(y)
|
||||
|
@ -450,7 +435,6 @@ func (z nat) mul(x, y nat) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// mulRange computes the product of all the unsigned integers in the
|
||||
// range [a, b] inclusively. If a > b (empty range), the result is 1.
|
||||
func (z nat) mulRange(a, b uint64) nat {
|
||||
|
@ -469,7 +453,6 @@ func (z nat) mulRange(a, b uint64) nat {
|
|||
return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
|
||||
}
|
||||
|
||||
|
||||
// q = (x-r)/y, with 0 <= r < y
|
||||
func (z nat) divW(x nat, y Word) (q nat, r Word) {
|
||||
m := len(x)
|
||||
|
@ -490,7 +473,6 @@ func (z nat) divW(x nat, y Word) (q nat, r Word) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
func (z nat) div(z2, u, v nat) (q, r nat) {
|
||||
if len(v) == 0 {
|
||||
panic("division by zero")
|
||||
|
@ -518,7 +500,6 @@ func (z nat) div(z2, u, v nat) (q, r nat) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// q = (uIn-r)/v, with 0 <= r < y
|
||||
// Uses z as storage for q, and u as storage for r if possible.
|
||||
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
|
||||
|
@ -545,9 +526,14 @@ func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
|
|||
u.clear()
|
||||
|
||||
// D1.
|
||||
shift := Word(leadingZeros(v[n-1]))
|
||||
shlVW(v, v, shift)
|
||||
u[len(uIn)] = shlVW(u[0:len(uIn)], uIn, shift)
|
||||
shift := leadingZeros(v[n-1])
|
||||
if shift > 0 {
|
||||
// do not modify v, it may be used by another goroutine simultaneously
|
||||
v1 := make(nat, n)
|
||||
shlVU(v1, v, shift)
|
||||
v = v1
|
||||
}
|
||||
u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)
|
||||
|
||||
// D2.
|
||||
for j := m; j >= 0; j-- {
|
||||
|
@ -586,14 +572,12 @@ func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
|
|||
}
|
||||
|
||||
q = q.norm()
|
||||
shrVW(u, u, shift)
|
||||
shrVW(v, v, shift)
|
||||
shrVU(u, u, shift)
|
||||
r = u.norm()
|
||||
|
||||
return q, r
|
||||
}
|
||||
|
||||
|
||||
// Length of x in bits. x must be normalized.
|
||||
func (x nat) bitLen() int {
|
||||
if i := len(x) - 1; i >= 0 {
|
||||
|
@ -602,103 +586,253 @@ func (x nat) bitLen() int {
|
|||
return 0
|
||||
}
|
||||
|
||||
// MaxBase is the largest number base accepted for string conversions.
|
||||
const MaxBase = 'z' - 'a' + 10 + 1 // = hexValue('z') + 1
|
||||
|
||||
func hexValue(ch byte) int {
|
||||
var d byte
|
||||
|
||||
func hexValue(ch int) Word {
|
||||
d := MaxBase + 1 // illegal base
|
||||
switch {
|
||||
case '0' <= ch && ch <= '9':
|
||||
d = ch - '0'
|
||||
case 'a' <= ch && ch <= 'f':
|
||||
case 'a' <= ch && ch <= 'z':
|
||||
d = ch - 'a' + 10
|
||||
case 'A' <= ch && ch <= 'F':
|
||||
case 'A' <= ch && ch <= 'Z':
|
||||
d = ch - 'A' + 10
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
return int(d)
|
||||
return Word(d)
|
||||
}
|
||||
|
||||
// scan sets z to the natural number corresponding to the longest possible prefix
|
||||
// read from r representing an unsigned integer in a given conversion base.
|
||||
// It returns z, the actual conversion base used, and an error, if any. In the
|
||||
// error case, the value of z is undefined. The syntax follows the syntax of
|
||||
// unsigned integer literals in Go.
|
||||
//
|
||||
// The base argument must be 0 or a value from 2 through MaxBase. If the base
|
||||
// is 0, the string prefix determines the actual conversion base. A prefix of
|
||||
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
|
||||
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
|
||||
//
|
||||
func (z nat) scan(r io.RuneScanner, base int) (nat, int, os.Error) {
|
||||
// reject illegal bases
|
||||
if base < 0 || base == 1 || MaxBase < base {
|
||||
return z, 0, os.NewError("illegal number base")
|
||||
}
|
||||
|
||||
// one char look-ahead
|
||||
ch, _, err := r.ReadRune()
|
||||
if err != nil {
|
||||
return z, 0, err
|
||||
}
|
||||
|
||||
// scan returns the natural number corresponding to the
|
||||
// longest possible prefix of s representing a natural number in a
|
||||
// given conversion base, the actual conversion base used, and the
|
||||
// prefix length. The syntax of natural numbers follows the syntax
|
||||
// of unsigned integer literals in Go.
|
||||
//
|
||||
// If the base argument is 0, the string prefix determines the actual
|
||||
// conversion base. A prefix of ``0x'' or ``0X'' selects base 16; the
|
||||
// ``0'' prefix selects base 8, and a ``0b'' or ``0B'' prefix selects
|
||||
// base 2. Otherwise the selected base is 10.
|
||||
//
|
||||
func (z nat) scan(s string, base int) (nat, int, int) {
|
||||
// determine base if necessary
|
||||
i, n := 0, len(s)
|
||||
b := Word(base)
|
||||
if base == 0 {
|
||||
base = 10
|
||||
if n > 0 && s[0] == '0' {
|
||||
base, i = 8, 1
|
||||
if n > 1 {
|
||||
switch s[1] {
|
||||
b = 10
|
||||
if ch == '0' {
|
||||
switch ch, _, err = r.ReadRune(); err {
|
||||
case nil:
|
||||
b = 8
|
||||
switch ch {
|
||||
case 'x', 'X':
|
||||
base, i = 16, 2
|
||||
b = 16
|
||||
case 'b', 'B':
|
||||
base, i = 2, 2
|
||||
b = 2
|
||||
}
|
||||
if b == 2 || b == 16 {
|
||||
if ch, _, err = r.ReadRune(); err != nil {
|
||||
return z, 0, err
|
||||
}
|
||||
}
|
||||
case os.EOF:
|
||||
return z, 10, nil
|
||||
default:
|
||||
return z, 10, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// convert string
|
||||
// - group as many digits d as possible together into a "super-digit" dd with "super-base" bb
|
||||
// - only when bb does not fit into a word anymore, do a full number mulAddWW using bb and dd
|
||||
z = z.make(0)
|
||||
bb := Word(1)
|
||||
dd := Word(0)
|
||||
for max := _M / b; ; {
|
||||
d := hexValue(ch)
|
||||
if d >= b {
|
||||
r.UnreadRune() // ch does not belong to number anymore
|
||||
break
|
||||
}
|
||||
|
||||
if bb <= max {
|
||||
bb *= b
|
||||
dd = dd*b + d
|
||||
} else {
|
||||
// bb * b would overflow
|
||||
z = z.mulAddWW(z, bb, dd)
|
||||
bb = b
|
||||
dd = d
|
||||
}
|
||||
|
||||
if ch, _, err = r.ReadRune(); err != nil {
|
||||
if err != os.EOF {
|
||||
return z, int(b), err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case bb > 1:
|
||||
// there was at least one mantissa digit
|
||||
z = z.mulAddWW(z, bb, dd)
|
||||
case base == 0 && b == 8:
|
||||
// there was only the octal prefix 0 (possibly followed by digits > 7);
|
||||
// return base 10, not 8
|
||||
return z, 10, nil
|
||||
case base != 0 || b != 8:
|
||||
// there was neither a mantissa digit nor the octal prefix 0
|
||||
return z, int(b), os.NewError("syntax error scanning number")
|
||||
}
|
||||
|
||||
return z.norm(), int(b), nil
|
||||
}
|
||||
|
||||
// Character sets for string conversion.
|
||||
const (
|
||||
lowercaseDigits = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
uppercaseDigits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
)
|
||||
|
||||
// decimalString returns a decimal representation of x.
|
||||
// It calls x.string with the charset "0123456789".
|
||||
func (x nat) decimalString() string {
|
||||
return x.string(lowercaseDigits[0:10])
|
||||
}
|
||||
|
||||
// string converts x to a string using digits from a charset; a digit with
|
||||
// value d is represented by charset[d]. The conversion base is determined
|
||||
// by len(charset), which must be >= 2.
|
||||
func (x nat) string(charset string) string {
|
||||
b := Word(len(charset))
|
||||
|
||||
// special cases
|
||||
switch {
|
||||
case b < 2 || b > 256:
|
||||
panic("illegal base")
|
||||
case len(x) == 0:
|
||||
return string(charset[0])
|
||||
}
|
||||
|
||||
// allocate buffer for conversion
|
||||
i := x.bitLen()/log2(b) + 1 // +1: round up
|
||||
s := make([]byte, i)
|
||||
|
||||
// special case: power of two bases can avoid divisions completely
|
||||
if b == b&-b {
|
||||
// shift is base-b digit size in bits
|
||||
shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2
|
||||
mask := Word(1)<<shift - 1
|
||||
w := x[0]
|
||||
nbits := uint(_W) // number of unprocessed bits in w
|
||||
|
||||
// convert less-significant words
|
||||
for k := 1; k < len(x); k++ {
|
||||
// convert full digits
|
||||
for nbits >= shift {
|
||||
i--
|
||||
s[i] = charset[w&mask]
|
||||
w >>= shift
|
||||
nbits -= shift
|
||||
}
|
||||
|
||||
// convert any partial leading digit and advance to next word
|
||||
if nbits == 0 {
|
||||
// no partial digit remaining, just advance
|
||||
w = x[k]
|
||||
nbits = _W
|
||||
} else {
|
||||
// partial digit in current (k-1) and next (k) word
|
||||
w |= x[k] << nbits
|
||||
i--
|
||||
s[i] = charset[w&mask]
|
||||
|
||||
// advance
|
||||
w = x[k] >> (shift - nbits)
|
||||
nbits = _W - (shift - nbits)
|
||||
}
|
||||
}
|
||||
|
||||
// convert digits of most-significant word (omit leading zeros)
|
||||
for nbits >= 0 && w != 0 {
|
||||
i--
|
||||
s[i] = charset[w&mask]
|
||||
w >>= shift
|
||||
nbits -= shift
|
||||
}
|
||||
|
||||
return string(s[i:])
|
||||
}
|
||||
|
||||
// general case: extract groups of digits by multiprecision division
|
||||
|
||||
// maximize ndigits where b**ndigits < 2^_W; bb (big base) is b**ndigits
|
||||
bb := Word(1)
|
||||
ndigits := 0
|
||||
for max := Word(_M / b); bb <= max; bb *= b {
|
||||
ndigits++
|
||||
}
|
||||
|
||||
// preserve x, create local copy for use in repeated divisions
|
||||
q := nat(nil).set(x)
|
||||
var r Word
|
||||
|
||||
// convert
|
||||
if b == 10 { // hard-coding for 10 here speeds this up by 1.25x
|
||||
for len(q) > 0 {
|
||||
// extract least significant, base bb "digit"
|
||||
q, r = q.divW(q, bb) // N.B. >82% of time is here. Optimize divW
|
||||
if len(q) == 0 {
|
||||
// skip leading zeros in most-significant group of digits
|
||||
for j := 0; j < ndigits && r != 0; j++ {
|
||||
i--
|
||||
s[i] = charset[r%10]
|
||||
r /= 10
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < ndigits; j++ {
|
||||
i--
|
||||
s[i] = charset[r%10]
|
||||
r /= 10
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for len(q) > 0 {
|
||||
// extract least significant group of digits
|
||||
q, r = q.divW(q, bb) // N.B. >82% of time is here. Optimize divW
|
||||
if len(q) == 0 {
|
||||
// skip leading zeros in most-significant group of digits
|
||||
for j := 0; j < ndigits && r != 0; j++ {
|
||||
i--
|
||||
s[i] = charset[r%b]
|
||||
r /= b
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < ndigits; j++ {
|
||||
i--
|
||||
s[i] = charset[r%b]
|
||||
r /= b
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reject illegal bases or strings consisting only of prefix
|
||||
if base < 2 || 16 < base || (base != 8 && i >= n) {
|
||||
return z, 0, 0
|
||||
}
|
||||
|
||||
// convert string
|
||||
z = z.make(0)
|
||||
for ; i < n; i++ {
|
||||
d := hexValue(s[i])
|
||||
if 0 <= d && d < base {
|
||||
z = z.mulAddWW(z, Word(base), Word(d))
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return z.norm(), base, i
|
||||
}
|
||||
|
||||
|
||||
// string converts x to a string for a given base, with 2 <= base <= 16.
|
||||
// TODO(gri) in the style of the other routines, perhaps this should take
|
||||
// a []byte buffer and return it
|
||||
func (x nat) string(base int) string {
|
||||
if base < 2 || 16 < base {
|
||||
panic("illegal base")
|
||||
}
|
||||
|
||||
if len(x) == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
// allocate buffer for conversion
|
||||
i := x.bitLen()/log2(Word(base)) + 1 // +1: round up
|
||||
s := make([]byte, i)
|
||||
|
||||
// don't destroy x
|
||||
q := nat(nil).set(x)
|
||||
|
||||
// convert
|
||||
for len(q) > 0 {
|
||||
i--
|
||||
var r Word
|
||||
q, r = q.divW(q, Word(base))
|
||||
s[i] = "0123456789abcdef"[r]
|
||||
}
|
||||
|
||||
return string(s[i:])
|
||||
}
|
||||
|
||||
|
||||
const deBruijn32 = 0x077CB531
|
||||
|
||||
var deBruijn32Lookup = []byte{
|
||||
|
@ -721,7 +855,7 @@ var deBruijn64Lookup = []byte{
|
|||
func trailingZeroBits(x Word) int {
|
||||
// x & -x leaves only the right-most bit set in the word. Let k be the
|
||||
// index of that bit. Since only a single bit is set, the value is two
|
||||
// to the power of k. Multipling by a power of two is equivalent to
|
||||
// to the power of k. Multiplying by a power of two is equivalent to
|
||||
// left shifting, in this case by k bits. The de Bruijn constant is
|
||||
// such that all six bit, consecutive substrings are distinct.
|
||||
// Therefore, if we have a left shifted version of this constant we can
|
||||
|
@ -739,7 +873,6 @@ func trailingZeroBits(x Word) int {
|
|||
return 0
|
||||
}
|
||||
|
||||
|
||||
// z = x << s
|
||||
func (z nat) shl(x nat, s uint) nat {
|
||||
m := len(x)
|
||||
|
@ -750,13 +883,12 @@ func (z nat) shl(x nat, s uint) nat {
|
|||
|
||||
n := m + int(s/_W)
|
||||
z = z.make(n + 1)
|
||||
z[n] = shlVW(z[n-m:n], x, Word(s%_W))
|
||||
z[n] = shlVU(z[n-m:n], x, s%_W)
|
||||
z[0 : n-m].clear()
|
||||
|
||||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// z = x >> s
|
||||
func (z nat) shr(x nat, s uint) nat {
|
||||
m := len(x)
|
||||
|
@ -767,11 +899,45 @@ func (z nat) shr(x nat, s uint) nat {
|
|||
// n > 0
|
||||
|
||||
z = z.make(n)
|
||||
shrVW(z, x[m-n:], Word(s%_W))
|
||||
shrVU(z, x[m-n:], s%_W)
|
||||
|
||||
return z.norm()
|
||||
}
|
||||
|
||||
func (z nat) setBit(x nat, i uint, b uint) nat {
|
||||
j := int(i / _W)
|
||||
m := Word(1) << (i % _W)
|
||||
n := len(x)
|
||||
switch b {
|
||||
case 0:
|
||||
z = z.make(n)
|
||||
copy(z, x)
|
||||
if j >= n {
|
||||
// no need to grow
|
||||
return z
|
||||
}
|
||||
z[j] &^= m
|
||||
return z.norm()
|
||||
case 1:
|
||||
if j >= n {
|
||||
n = j + 1
|
||||
}
|
||||
z = z.make(n)
|
||||
copy(z, x)
|
||||
z[j] |= m
|
||||
// no need to normalize
|
||||
return z
|
||||
}
|
||||
panic("set bit is not 0 or 1")
|
||||
}
|
||||
|
||||
func (z nat) bit(i uint) uint {
|
||||
j := int(i / _W)
|
||||
if j >= len(z) {
|
||||
return 0
|
||||
}
|
||||
return uint(z[j] >> (i % _W) & 1)
|
||||
}
|
||||
|
||||
func (z nat) and(x, y nat) nat {
|
||||
m := len(x)
|
||||
|
@ -789,7 +955,6 @@ func (z nat) and(x, y nat) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
func (z nat) andNot(x, y nat) nat {
|
||||
m := len(x)
|
||||
n := len(y)
|
||||
|
@ -807,7 +972,6 @@ func (z nat) andNot(x, y nat) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
func (z nat) or(x, y nat) nat {
|
||||
m := len(x)
|
||||
n := len(y)
|
||||
|
@ -827,7 +991,6 @@ func (z nat) or(x, y nat) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
func (z nat) xor(x, y nat) nat {
|
||||
m := len(x)
|
||||
n := len(y)
|
||||
|
@ -847,10 +1010,10 @@ func (z nat) xor(x, y nat) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2)
|
||||
func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 }
|
||||
|
||||
func greaterThan(x1, x2, y1, y2 Word) bool {
|
||||
return x1 > y1 || x1 == y1 && x2 > y2
|
||||
}
|
||||
|
||||
// modW returns x % d.
|
||||
func (x nat) modW(d Word) (r Word) {
|
||||
|
@ -860,30 +1023,29 @@ func (x nat) modW(d Word) (r Word) {
|
|||
return divWVW(q, 0, x, d)
|
||||
}
|
||||
|
||||
|
||||
// powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd.
|
||||
func (n nat) powersOfTwoDecompose() (q nat, k Word) {
|
||||
if len(n) == 0 {
|
||||
return n, 0
|
||||
// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0.
|
||||
func (x nat) powersOfTwoDecompose() (q nat, k int) {
|
||||
if len(x) == 0 {
|
||||
return x, 0
|
||||
}
|
||||
|
||||
zeroWords := 0
|
||||
for n[zeroWords] == 0 {
|
||||
zeroWords++
|
||||
// One of the words must be non-zero by definition,
|
||||
// so this loop will terminate with i < len(x), and
|
||||
// i is the number of 0 words.
|
||||
i := 0
|
||||
for x[i] == 0 {
|
||||
i++
|
||||
}
|
||||
// One of the words must be non-zero by invariant, therefore
|
||||
// zeroWords < len(n).
|
||||
x := trailingZeroBits(n[zeroWords])
|
||||
n := trailingZeroBits(x[i]) // x[i] != 0
|
||||
|
||||
q = make(nat, len(x)-i)
|
||||
shrVU(q, x[i:], uint(n))
|
||||
|
||||
q = q.make(len(n) - zeroWords)
|
||||
shrVW(q, n[zeroWords:], Word(x))
|
||||
q = q.norm()
|
||||
|
||||
k = Word(_W*zeroWords + x)
|
||||
k = i*_W + n
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// random creates a random integer in [0..limit), using the space in z if
|
||||
// possible. n is the bit length of limit.
|
||||
func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
|
||||
|
@ -914,7 +1076,6 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It
|
||||
// reuses the storage of z if possible.
|
||||
func (z nat) expNN(x, y, m nat) nat {
|
||||
|
@ -983,7 +1144,6 @@ func (z nat) expNN(x, y, m nat) nat {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
|
||||
// If it returns true, n is prime with probability 1 - 1/4^reps.
|
||||
// If it returns false, n is not prime.
|
||||
|
@ -1050,7 +1210,7 @@ NextRandom:
|
|||
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
|
||||
continue
|
||||
}
|
||||
for j := Word(1); j < k; j++ {
|
||||
for j := 1; j < k; j++ {
|
||||
y = y.mul(y, y)
|
||||
quotient, y = quotient.div(y, y, n)
|
||||
if y.cmp(nm1) == 0 {
|
||||
|
@ -1066,7 +1226,6 @@ NextRandom:
|
|||
return true
|
||||
}
|
||||
|
||||
|
||||
// bytes writes the value of z into buf using big-endian encoding.
|
||||
// len(buf) must be >= len(z)*_S. The value of z is encoded in the
|
||||
// slice buf[i:]. The number i of unused bytes at the beginning of
|
||||
|
@ -1088,7 +1247,6 @@ func (z nat) bytes(buf []byte) (i int) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// setBytes interprets buf as the bytes of a big-endian unsigned
|
||||
// integer, sets z to that value, and returns z.
|
||||
func (z nat) setBytes(buf []byte) nat {
|
||||
|
|
|
@ -4,7 +4,12 @@
|
|||
|
||||
package big
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var cmpTests = []struct {
|
||||
x, y nat
|
||||
|
@ -26,7 +31,6 @@ var cmpTests = []struct {
|
|||
{nat{34986, 41, 105, 1957}, nat{56, 7458, 104, 1957}, 1},
|
||||
}
|
||||
|
||||
|
||||
func TestCmp(t *testing.T) {
|
||||
for i, a := range cmpTests {
|
||||
r := a.x.cmp(a.y)
|
||||
|
@ -36,13 +40,11 @@ func TestCmp(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type funNN func(z, x, y nat) nat
|
||||
type argNN struct {
|
||||
z, x, y nat
|
||||
}
|
||||
|
||||
|
||||
var sumNN = []argNN{
|
||||
{},
|
||||
{nat{1}, nil, nat{1}},
|
||||
|
@ -52,7 +54,6 @@ var sumNN = []argNN{
|
|||
{nat{0, 0, 0, 1}, nat{0, 0, _M}, nat{0, 0, 1}},
|
||||
}
|
||||
|
||||
|
||||
var prodNN = []argNN{
|
||||
{},
|
||||
{nil, nil, nil},
|
||||
|
@ -64,7 +65,6 @@ var prodNN = []argNN{
|
|||
{nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
|
||||
}
|
||||
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
for _, a := range sumNN {
|
||||
z := nat(nil).set(a.z)
|
||||
|
@ -74,7 +74,6 @@ func TestSet(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
|
||||
z := f(nil, a.x, a.y)
|
||||
if z.cmp(a.z) != 0 {
|
||||
|
@ -82,7 +81,6 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestFunNN(t *testing.T) {
|
||||
for _, a := range sumNN {
|
||||
arg := a
|
||||
|
@ -107,7 +105,6 @@ func TestFunNN(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var mulRangesN = []struct {
|
||||
a, b uint64
|
||||
prod string
|
||||
|
@ -130,17 +127,15 @@ var mulRangesN = []struct {
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
func TestMulRangeN(t *testing.T) {
|
||||
for i, r := range mulRangesN {
|
||||
prod := nat(nil).mulRange(r.a, r.b).string(10)
|
||||
prod := nat(nil).mulRange(r.a, r.b).decimalString()
|
||||
if prod != r.prod {
|
||||
t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
var mulArg, mulTmp nat
|
||||
|
||||
func init() {
|
||||
|
@ -151,7 +146,6 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func benchmarkMulLoad() {
|
||||
for j := 1; j <= 10; j++ {
|
||||
x := mulArg[0 : j*100]
|
||||
|
@ -159,46 +153,376 @@ func benchmarkMulLoad() {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func BenchmarkMul(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkMulLoad()
|
||||
}
|
||||
}
|
||||
|
||||
func toString(x nat, charset string) string {
|
||||
base := len(charset)
|
||||
|
||||
var tab = []struct {
|
||||
x nat
|
||||
b int
|
||||
s string
|
||||
}{
|
||||
{nil, 10, "0"},
|
||||
{nat{1}, 10, "1"},
|
||||
{nat{10}, 10, "10"},
|
||||
{nat{1234567890}, 10, "1234567890"},
|
||||
// special cases
|
||||
switch {
|
||||
case base < 2:
|
||||
panic("illegal base")
|
||||
case len(x) == 0:
|
||||
return string(charset[0])
|
||||
}
|
||||
|
||||
// allocate buffer for conversion
|
||||
i := x.bitLen()/log2(Word(base)) + 1 // +1: round up
|
||||
s := make([]byte, i)
|
||||
|
||||
// don't destroy x
|
||||
q := nat(nil).set(x)
|
||||
|
||||
// convert
|
||||
for len(q) > 0 {
|
||||
i--
|
||||
var r Word
|
||||
q, r = q.divW(q, Word(base))
|
||||
s[i] = charset[r]
|
||||
}
|
||||
|
||||
return string(s[i:])
|
||||
}
|
||||
|
||||
var strTests = []struct {
|
||||
x nat // nat value to be converted
|
||||
c string // conversion charset
|
||||
s string // expected result
|
||||
}{
|
||||
{nil, "01", "0"},
|
||||
{nat{1}, "01", "1"},
|
||||
{nat{0xc5}, "01", "11000101"},
|
||||
{nat{03271}, lowercaseDigits[0:8], "3271"},
|
||||
{nat{10}, lowercaseDigits[0:10], "10"},
|
||||
{nat{1234567890}, uppercaseDigits[0:10], "1234567890"},
|
||||
{nat{0xdeadbeef}, lowercaseDigits[0:16], "deadbeef"},
|
||||
{nat{0xdeadbeef}, uppercaseDigits[0:16], "DEADBEEF"},
|
||||
{nat{0x229be7}, lowercaseDigits[0:17], "1a2b3c"},
|
||||
{nat{0x309663e6}, uppercaseDigits[0:32], "O9COV6"},
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
for _, a := range tab {
|
||||
s := a.x.string(a.b)
|
||||
for _, a := range strTests {
|
||||
s := a.x.string(a.c)
|
||||
if s != a.s {
|
||||
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
|
||||
}
|
||||
|
||||
x, b, n := nat(nil).scan(a.s, a.b)
|
||||
x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
|
||||
if x.cmp(a.x) != 0 {
|
||||
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
|
||||
}
|
||||
if b != a.b {
|
||||
t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.b)
|
||||
if b != len(a.c) {
|
||||
t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, len(a.c))
|
||||
}
|
||||
if n != len(a.s) {
|
||||
t.Errorf("scan%+v\n\tgot n = %d; want %d", a, n, len(a.s))
|
||||
if err != nil {
|
||||
t.Errorf("scan%+v\n\tgot error = %s", a, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var natScanTests = []struct {
|
||||
s string // string to be scanned
|
||||
base int // input base
|
||||
x nat // expected nat
|
||||
b int // expected base
|
||||
ok bool // expected success
|
||||
next int // next character (or 0, if at EOF)
|
||||
}{
|
||||
// error: illegal base
|
||||
{base: -1},
|
||||
{base: 1},
|
||||
{base: 37},
|
||||
|
||||
// error: no mantissa
|
||||
{},
|
||||
{s: "?"},
|
||||
{base: 10},
|
||||
{base: 36},
|
||||
{s: "?", base: 10},
|
||||
{s: "0x"},
|
||||
{s: "345", base: 2},
|
||||
|
||||
// no errors
|
||||
{"0", 0, nil, 10, true, 0},
|
||||
{"0", 10, nil, 10, true, 0},
|
||||
{"0", 36, nil, 36, true, 0},
|
||||
{"1", 0, nat{1}, 10, true, 0},
|
||||
{"1", 10, nat{1}, 10, true, 0},
|
||||
{"0 ", 0, nil, 10, true, ' '},
|
||||
{"08", 0, nil, 10, true, '8'},
|
||||
{"018", 0, nat{1}, 8, true, '8'},
|
||||
{"0b1", 0, nat{1}, 2, true, 0},
|
||||
{"0b11000101", 0, nat{0xc5}, 2, true, 0},
|
||||
{"03271", 0, nat{03271}, 8, true, 0},
|
||||
{"10ab", 0, nat{10}, 10, true, 'a'},
|
||||
{"1234567890", 0, nat{1234567890}, 10, true, 0},
|
||||
{"xyz", 36, nat{(33*36+34)*36 + 35}, 36, true, 0},
|
||||
{"xyz?", 36, nat{(33*36+34)*36 + 35}, 36, true, '?'},
|
||||
{"0x", 16, nil, 16, true, 'x'},
|
||||
{"0xdeadbeef", 0, nat{0xdeadbeef}, 16, true, 0},
|
||||
{"0XDEADBEEF", 0, nat{0xdeadbeef}, 16, true, 0},
|
||||
}
|
||||
|
||||
func TestScanBase(t *testing.T) {
|
||||
for _, a := range natScanTests {
|
||||
r := strings.NewReader(a.s)
|
||||
x, b, err := nat(nil).scan(r, a.base)
|
||||
if err == nil && !a.ok {
|
||||
t.Errorf("scan%+v\n\texpected error", a)
|
||||
}
|
||||
if err != nil {
|
||||
if a.ok {
|
||||
t.Errorf("scan%+v\n\tgot error = %s", a, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if x.cmp(a.x) != 0 {
|
||||
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
|
||||
}
|
||||
if b != a.b {
|
||||
t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.base)
|
||||
}
|
||||
next, _, err := r.ReadRune()
|
||||
if err == os.EOF {
|
||||
next = 0
|
||||
err = nil
|
||||
}
|
||||
if err == nil && next != a.next {
|
||||
t.Errorf("scan%+v\n\tgot next = %q; want %q", a, next, a.next)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var pi = "3" +
|
||||
"14159265358979323846264338327950288419716939937510582097494459230781640628620899862803482534211706798214808651" +
|
||||
"32823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461" +
|
||||
"28475648233786783165271201909145648566923460348610454326648213393607260249141273724587006606315588174881520920" +
|
||||
"96282925409171536436789259036001133053054882046652138414695194151160943305727036575959195309218611738193261179" +
|
||||
"31051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798" +
|
||||
"60943702770539217176293176752384674818467669405132000568127145263560827785771342757789609173637178721468440901" +
|
||||
"22495343014654958537105079227968925892354201995611212902196086403441815981362977477130996051870721134999999837" +
|
||||
"29780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083" +
|
||||
"81420617177669147303598253490428755468731159562863882353787593751957781857780532171226806613001927876611195909" +
|
||||
"21642019893809525720106548586327886593615338182796823030195203530185296899577362259941389124972177528347913151" +
|
||||
"55748572424541506959508295331168617278558890750983817546374649393192550604009277016711390098488240128583616035" +
|
||||
"63707660104710181942955596198946767837449448255379774726847104047534646208046684259069491293313677028989152104" +
|
||||
"75216205696602405803815019351125338243003558764024749647326391419927260426992279678235478163600934172164121992" +
|
||||
"45863150302861829745557067498385054945885869269956909272107975093029553211653449872027559602364806654991198818" +
|
||||
"34797753566369807426542527862551818417574672890977772793800081647060016145249192173217214772350141441973568548" +
|
||||
"16136115735255213347574184946843852332390739414333454776241686251898356948556209921922218427255025425688767179" +
|
||||
"04946016534668049886272327917860857843838279679766814541009538837863609506800642251252051173929848960841284886" +
|
||||
"26945604241965285022210661186306744278622039194945047123713786960956364371917287467764657573962413890865832645" +
|
||||
"99581339047802759009946576407895126946839835259570982582262052248940772671947826848260147699090264013639443745" +
|
||||
"53050682034962524517493996514314298091906592509372216964615157098583874105978859597729754989301617539284681382" +
|
||||
"68683868942774155991855925245953959431049972524680845987273644695848653836736222626099124608051243884390451244" +
|
||||
"13654976278079771569143599770012961608944169486855584840635342207222582848864815845602850601684273945226746767" +
|
||||
"88952521385225499546667278239864565961163548862305774564980355936345681743241125150760694794510965960940252288" +
|
||||
"79710893145669136867228748940560101503308617928680920874760917824938589009714909675985261365549781893129784821" +
|
||||
"68299894872265880485756401427047755513237964145152374623436454285844479526586782105114135473573952311342716610" +
|
||||
"21359695362314429524849371871101457654035902799344037420073105785390621983874478084784896833214457138687519435" +
|
||||
"06430218453191048481005370614680674919278191197939952061419663428754440643745123718192179998391015919561814675" +
|
||||
"14269123974894090718649423196156794520809514655022523160388193014209376213785595663893778708303906979207734672" +
|
||||
"21825625996615014215030680384477345492026054146659252014974428507325186660021324340881907104863317346496514539" +
|
||||
"05796268561005508106658796998163574736384052571459102897064140110971206280439039759515677157700420337869936007" +
|
||||
"23055876317635942187312514712053292819182618612586732157919841484882916447060957527069572209175671167229109816" +
|
||||
"90915280173506712748583222871835209353965725121083579151369882091444210067510334671103141267111369908658516398" +
|
||||
"31501970165151168517143765761835155650884909989859982387345528331635507647918535893226185489632132933089857064" +
|
||||
"20467525907091548141654985946163718027098199430992448895757128289059232332609729971208443357326548938239119325" +
|
||||
"97463667305836041428138830320382490375898524374417029132765618093773444030707469211201913020330380197621101100" +
|
||||
"44929321516084244485963766983895228684783123552658213144957685726243344189303968642624341077322697802807318915" +
|
||||
"44110104468232527162010526522721116603966655730925471105578537634668206531098965269186205647693125705863566201" +
|
||||
"85581007293606598764861179104533488503461136576867532494416680396265797877185560845529654126654085306143444318" +
|
||||
"58676975145661406800700237877659134401712749470420562230538994561314071127000407854733269939081454664645880797" +
|
||||
"27082668306343285878569830523580893306575740679545716377525420211495576158140025012622859413021647155097925923" +
|
||||
"09907965473761255176567513575178296664547791745011299614890304639947132962107340437518957359614589019389713111" +
|
||||
"79042978285647503203198691514028708085990480109412147221317947647772622414254854540332157185306142288137585043" +
|
||||
"06332175182979866223717215916077166925474873898665494945011465406284336639379003976926567214638530673609657120" +
|
||||
"91807638327166416274888800786925602902284721040317211860820419000422966171196377921337575114959501566049631862" +
|
||||
"94726547364252308177036751590673502350728354056704038674351362222477158915049530984448933309634087807693259939" +
|
||||
"78054193414473774418426312986080998886874132604721569516239658645730216315981931951673538129741677294786724229" +
|
||||
"24654366800980676928238280689964004824354037014163149658979409243237896907069779422362508221688957383798623001" +
|
||||
"59377647165122893578601588161755782973523344604281512627203734314653197777416031990665541876397929334419521541" +
|
||||
"34189948544473456738316249934191318148092777710386387734317720754565453220777092120190516609628049092636019759" +
|
||||
"88281613323166636528619326686336062735676303544776280350450777235547105859548702790814356240145171806246436267" +
|
||||
"94561275318134078330336254232783944975382437205835311477119926063813346776879695970309833913077109870408591337"
|
||||
|
||||
// Test case for BenchmarkScanPi.
|
||||
func TestScanPi(t *testing.T) {
|
||||
var x nat
|
||||
z, _, err := x.scan(strings.NewReader(pi), 10)
|
||||
if err != nil {
|
||||
t.Errorf("scanning pi: %s", err)
|
||||
}
|
||||
if s := z.decimalString(); s != pi {
|
||||
t.Errorf("scanning pi: got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkScanPi(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var x nat
|
||||
x.scan(strings.NewReader(pi), 10)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// 314**271
|
||||
// base 2: 2249 digits
|
||||
// base 8: 751 digits
|
||||
// base 10: 678 digits
|
||||
// base 16: 563 digits
|
||||
shortBase = 314
|
||||
shortExponent = 271
|
||||
|
||||
// 3141**2178
|
||||
// base 2: 31577 digits
|
||||
// base 8: 10527 digits
|
||||
// base 10: 9507 digits
|
||||
// base 16: 7895 digits
|
||||
mediumBase = 3141
|
||||
mediumExponent = 2718
|
||||
|
||||
// 3141**2178
|
||||
// base 2: 406078 digits
|
||||
// base 8: 135360 digits
|
||||
// base 10: 122243 digits
|
||||
// base 16: 101521 digits
|
||||
longBase = 31415
|
||||
longExponent = 27182
|
||||
)
|
||||
|
||||
func BenchmarkScanShort2(b *testing.B) {
|
||||
ScanHelper(b, 2, shortBase, shortExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanShort8(b *testing.B) {
|
||||
ScanHelper(b, 8, shortBase, shortExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanSort10(b *testing.B) {
|
||||
ScanHelper(b, 10, shortBase, shortExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanShort16(b *testing.B) {
|
||||
ScanHelper(b, 16, shortBase, shortExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanMedium2(b *testing.B) {
|
||||
ScanHelper(b, 2, mediumBase, mediumExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanMedium8(b *testing.B) {
|
||||
ScanHelper(b, 8, mediumBase, mediumExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanMedium10(b *testing.B) {
|
||||
ScanHelper(b, 10, mediumBase, mediumExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanMedium16(b *testing.B) {
|
||||
ScanHelper(b, 16, mediumBase, mediumExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanLong2(b *testing.B) {
|
||||
ScanHelper(b, 2, longBase, longExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanLong8(b *testing.B) {
|
||||
ScanHelper(b, 8, longBase, longExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanLong10(b *testing.B) {
|
||||
ScanHelper(b, 10, longBase, longExponent)
|
||||
}
|
||||
|
||||
func BenchmarkScanLong16(b *testing.B) {
|
||||
ScanHelper(b, 16, longBase, longExponent)
|
||||
}
|
||||
|
||||
func ScanHelper(b *testing.B, base int, xv, yv Word) {
|
||||
b.StopTimer()
|
||||
var x, y, z nat
|
||||
x = x.setWord(xv)
|
||||
y = y.setWord(yv)
|
||||
z = z.expNN(x, y, nil)
|
||||
|
||||
var s string
|
||||
s = z.string(lowercaseDigits[0:base])
|
||||
if t := toString(z, lowercaseDigits[0:base]); t != s {
|
||||
panic(fmt.Sprintf("scanning: got %s; want %s", s, t))
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
x.scan(strings.NewReader(s), base)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStringShort2(b *testing.B) {
|
||||
StringHelper(b, 2, shortBase, shortExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringShort8(b *testing.B) {
|
||||
StringHelper(b, 8, shortBase, shortExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringShort10(b *testing.B) {
|
||||
StringHelper(b, 10, shortBase, shortExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringShort16(b *testing.B) {
|
||||
StringHelper(b, 16, shortBase, shortExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringMedium2(b *testing.B) {
|
||||
StringHelper(b, 2, mediumBase, mediumExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringMedium8(b *testing.B) {
|
||||
StringHelper(b, 8, mediumBase, mediumExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringMedium10(b *testing.B) {
|
||||
StringHelper(b, 10, mediumBase, mediumExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringMedium16(b *testing.B) {
|
||||
StringHelper(b, 16, mediumBase, mediumExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringLong2(b *testing.B) {
|
||||
StringHelper(b, 2, longBase, longExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringLong8(b *testing.B) {
|
||||
StringHelper(b, 8, longBase, longExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringLong10(b *testing.B) {
|
||||
StringHelper(b, 10, longBase, longExponent)
|
||||
}
|
||||
|
||||
func BenchmarkStringLong16(b *testing.B) {
|
||||
StringHelper(b, 16, longBase, longExponent)
|
||||
}
|
||||
|
||||
func StringHelper(b *testing.B, base int, xv, yv Word) {
|
||||
b.StopTimer()
|
||||
var x, y, z nat
|
||||
x = x.setWord(xv)
|
||||
y = y.setWord(yv)
|
||||
z = z.expNN(x, y, nil)
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
z.string(lowercaseDigits[0:base])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeadingZeros(t *testing.T) {
|
||||
var x Word = _B >> 1
|
||||
|
@ -210,14 +534,12 @@ func TestLeadingZeros(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type shiftTest struct {
|
||||
in nat
|
||||
shift uint
|
||||
out nat
|
||||
}
|
||||
|
||||
|
||||
var leftShiftTests = []shiftTest{
|
||||
{nil, 0, nil},
|
||||
{nil, 1, nil},
|
||||
|
@ -227,7 +549,6 @@ var leftShiftTests = []shiftTest{
|
|||
{nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
|
||||
}
|
||||
|
||||
|
||||
func TestShiftLeft(t *testing.T) {
|
||||
for i, test := range leftShiftTests {
|
||||
var z nat
|
||||
|
@ -241,7 +562,6 @@ func TestShiftLeft(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var rightShiftTests = []shiftTest{
|
||||
{nil, 0, nil},
|
||||
{nil, 1, nil},
|
||||
|
@ -252,7 +572,6 @@ var rightShiftTests = []shiftTest{
|
|||
{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
|
||||
}
|
||||
|
||||
|
||||
func TestShiftRight(t *testing.T) {
|
||||
for i, test := range rightShiftTests {
|
||||
var z nat
|
||||
|
@ -266,24 +585,20 @@ func TestShiftRight(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type modWTest struct {
|
||||
in string
|
||||
dividend string
|
||||
out string
|
||||
}
|
||||
|
||||
|
||||
var modWTests32 = []modWTest{
|
||||
{"23492635982634928349238759823742", "252341", "220170"},
|
||||
}
|
||||
|
||||
|
||||
var modWTests64 = []modWTest{
|
||||
{"6527895462947293856291561095690465243862946", "524326975699234", "375066989628668"},
|
||||
}
|
||||
|
||||
|
||||
func runModWTests(t *testing.T, tests []modWTest) {
|
||||
for i, test := range tests {
|
||||
in, _ := new(Int).SetString(test.in, 10)
|
||||
|
@ -297,7 +612,6 @@ func runModWTests(t *testing.T, tests []modWTest) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestModW(t *testing.T) {
|
||||
if _W >= 32 {
|
||||
runModWTests(t, modWTests32)
|
||||
|
@ -307,7 +621,6 @@ func TestModW(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestTrailingZeroBits(t *testing.T) {
|
||||
var x Word
|
||||
x--
|
||||
|
@ -319,7 +632,6 @@ func TestTrailingZeroBits(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var expNNTests = []struct {
|
||||
x, y, m string
|
||||
out string
|
||||
|
@ -337,17 +649,16 @@ var expNNTests = []struct {
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
func TestExpNN(t *testing.T) {
|
||||
for i, test := range expNNTests {
|
||||
x, _, _ := nat(nil).scan(test.x, 0)
|
||||
y, _, _ := nat(nil).scan(test.y, 0)
|
||||
out, _, _ := nat(nil).scan(test.out, 0)
|
||||
x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0)
|
||||
y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0)
|
||||
out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0)
|
||||
|
||||
var m nat
|
||||
|
||||
if len(test.m) > 0 {
|
||||
m, _, _ = nat(nil).scan(test.m, 0)
|
||||
m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0)
|
||||
}
|
||||
|
||||
z := nat(nil).expNN(x, y, m)
|
||||
|
|
|
@ -6,7 +6,12 @@
|
|||
|
||||
package big
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A Rat represents a quotient a/b of arbitrary precision. The zero value for
|
||||
// a Rat, 0/0, is not a legal Rat.
|
||||
|
@ -15,13 +20,11 @@ type Rat struct {
|
|||
b nat
|
||||
}
|
||||
|
||||
|
||||
// NewRat creates a new Rat with numerator a and denominator b.
|
||||
func NewRat(a, b int64) *Rat {
|
||||
return new(Rat).SetFrac64(a, b)
|
||||
}
|
||||
|
||||
|
||||
// SetFrac sets z to a/b and returns z.
|
||||
func (z *Rat) SetFrac(a, b *Int) *Rat {
|
||||
z.a.Set(a)
|
||||
|
@ -30,7 +33,6 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// SetFrac64 sets z to a/b and returns z.
|
||||
func (z *Rat) SetFrac64(a, b int64) *Rat {
|
||||
z.a.SetInt64(a)
|
||||
|
@ -42,7 +44,6 @@ func (z *Rat) SetFrac64(a, b int64) *Rat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// SetInt sets z to x (by making a copy of x) and returns z.
|
||||
func (z *Rat) SetInt(x *Int) *Rat {
|
||||
z.a.Set(x)
|
||||
|
@ -50,7 +51,6 @@ func (z *Rat) SetInt(x *Int) *Rat {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// SetInt64 sets z to x and returns z.
|
||||
func (z *Rat) SetInt64(x int64) *Rat {
|
||||
z.a.SetInt64(x)
|
||||
|
@ -58,7 +58,6 @@ func (z *Rat) SetInt64(x int64) *Rat {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Sign returns:
|
||||
//
|
||||
// -1 if x < 0
|
||||
|
@ -69,13 +68,11 @@ func (x *Rat) Sign() int {
|
|||
return x.a.Sign()
|
||||
}
|
||||
|
||||
|
||||
// IsInt returns true if the denominator of x is 1.
|
||||
func (x *Rat) IsInt() bool {
|
||||
return len(x.b) == 1 && x.b[0] == 1
|
||||
}
|
||||
|
||||
|
||||
// Num returns the numerator of z; it may be <= 0.
|
||||
// The result is a reference to z's numerator; it
|
||||
// may change if a new value is assigned to z.
|
||||
|
@ -83,15 +80,13 @@ func (z *Rat) Num() *Int {
|
|||
return &z.a
|
||||
}
|
||||
|
||||
|
||||
// Demom returns the denominator of z; it is always > 0.
|
||||
// Denom returns the denominator of z; it is always > 0.
|
||||
// The result is a reference to z's denominator; it
|
||||
// may change if a new value is assigned to z.
|
||||
func (z *Rat) Denom() *Int {
|
||||
return &Int{false, z.b}
|
||||
}
|
||||
|
||||
|
||||
func gcd(x, y nat) nat {
|
||||
// Euclidean algorithm.
|
||||
var a, b nat
|
||||
|
@ -106,7 +101,6 @@ func gcd(x, y nat) nat {
|
|||
return a
|
||||
}
|
||||
|
||||
|
||||
func (z *Rat) norm() *Rat {
|
||||
f := gcd(z.a.abs, z.b)
|
||||
if len(z.a.abs) == 0 {
|
||||
|
@ -122,7 +116,6 @@ func (z *Rat) norm() *Rat {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
func mulNat(x *Int, y nat) *Int {
|
||||
var z Int
|
||||
z.abs = z.abs.mul(x.abs, y)
|
||||
|
@ -130,7 +123,6 @@ func mulNat(x *Int, y nat) *Int {
|
|||
return &z
|
||||
}
|
||||
|
||||
|
||||
// Cmp compares x and y and returns:
|
||||
//
|
||||
// -1 if x < y
|
||||
|
@ -141,7 +133,6 @@ func (x *Rat) Cmp(y *Rat) (r int) {
|
|||
return mulNat(&x.a, y.b).Cmp(mulNat(&y.a, x.b))
|
||||
}
|
||||
|
||||
|
||||
// Abs sets z to |x| (the absolute value of x) and returns z.
|
||||
func (z *Rat) Abs(x *Rat) *Rat {
|
||||
z.a.Abs(&x.a)
|
||||
|
@ -149,7 +140,6 @@ func (z *Rat) Abs(x *Rat) *Rat {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Add sets z to the sum x+y and returns z.
|
||||
func (z *Rat) Add(x, y *Rat) *Rat {
|
||||
a1 := mulNat(&x.a, y.b)
|
||||
|
@ -159,7 +149,6 @@ func (z *Rat) Add(x, y *Rat) *Rat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// Sub sets z to the difference x-y and returns z.
|
||||
func (z *Rat) Sub(x, y *Rat) *Rat {
|
||||
a1 := mulNat(&x.a, y.b)
|
||||
|
@ -169,7 +158,6 @@ func (z *Rat) Sub(x, y *Rat) *Rat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// Mul sets z to the product x*y and returns z.
|
||||
func (z *Rat) Mul(x, y *Rat) *Rat {
|
||||
z.a.Mul(&x.a, &y.a)
|
||||
|
@ -177,7 +165,6 @@ func (z *Rat) Mul(x, y *Rat) *Rat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// Quo sets z to the quotient x/y and returns z.
|
||||
// If y == 0, a division-by-zero run-time panic occurs.
|
||||
func (z *Rat) Quo(x, y *Rat) *Rat {
|
||||
|
@ -192,7 +179,6 @@ func (z *Rat) Quo(x, y *Rat) *Rat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// Neg sets z to -x (by making a copy of x if necessary) and returns z.
|
||||
func (z *Rat) Neg(x *Rat) *Rat {
|
||||
z.a.Neg(&x.a)
|
||||
|
@ -200,7 +186,6 @@ func (z *Rat) Neg(x *Rat) *Rat {
|
|||
return z
|
||||
}
|
||||
|
||||
|
||||
// Set sets z to x (by making a copy of x if necessary) and returns z.
|
||||
func (z *Rat) Set(x *Rat) *Rat {
|
||||
z.a.Set(&x.a)
|
||||
|
@ -208,6 +193,25 @@ func (z *Rat) Set(x *Rat) *Rat {
|
|||
return z
|
||||
}
|
||||
|
||||
func ratTok(ch int) bool {
|
||||
return strings.IndexRune("+-/0123456789.eE", ch) >= 0
|
||||
}
|
||||
|
||||
// Scan is a support routine for fmt.Scanner. It accepts the formats
|
||||
// 'e', 'E', 'f', 'F', 'g', 'G', and 'v'. All formats are equivalent.
|
||||
func (z *Rat) Scan(s fmt.ScanState, ch int) os.Error {
|
||||
tok, err := s.Token(true, ratTok)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.IndexRune("efgEFGv", ch) < 0 {
|
||||
return os.NewError("Rat.Scan: invalid verb")
|
||||
}
|
||||
if _, ok := z.SetString(string(tok)); !ok {
|
||||
return os.NewError("Rat.Scan: invalid syntax")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetString sets z to the value of s and returns z and a boolean indicating
|
||||
// success. s can be given as a fraction "a/b" or as a floating-point number
|
||||
|
@ -225,8 +229,8 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
|
|||
return z, false
|
||||
}
|
||||
s = s[sep+1:]
|
||||
var n int
|
||||
if z.b, _, n = z.b.scan(s, 10); n != len(s) {
|
||||
var err os.Error
|
||||
if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil {
|
||||
return z, false
|
||||
}
|
||||
return z.norm(), true
|
||||
|
@ -267,13 +271,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
|
|||
return z, true
|
||||
}
|
||||
|
||||
|
||||
// String returns a string representation of z in the form "a/b" (even if b == 1).
|
||||
func (z *Rat) String() string {
|
||||
return z.a.String() + "/" + z.b.string(10)
|
||||
return z.a.String() + "/" + z.b.decimalString()
|
||||
}
|
||||
|
||||
|
||||
// RatString returns a string representation of z in the form "a/b" if b != 1,
|
||||
// and in the form "a" if b == 1.
|
||||
func (z *Rat) RatString() string {
|
||||
|
@ -283,12 +285,15 @@ func (z *Rat) RatString() string {
|
|||
return z.String()
|
||||
}
|
||||
|
||||
|
||||
// FloatString returns a string representation of z in decimal form with prec
|
||||
// digits of precision after the decimal point and the last digit rounded.
|
||||
func (z *Rat) FloatString(prec int) string {
|
||||
if z.IsInt() {
|
||||
return z.a.String()
|
||||
s := z.a.String()
|
||||
if prec > 0 {
|
||||
s += "." + strings.Repeat("0", prec)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
q, r := nat{}.div(nat{}, z.a.abs, z.b)
|
||||
|
@ -311,16 +316,56 @@ func (z *Rat) FloatString(prec int) string {
|
|||
}
|
||||
}
|
||||
|
||||
s := q.string(10)
|
||||
s := q.decimalString()
|
||||
if z.a.neg {
|
||||
s = "-" + s
|
||||
}
|
||||
|
||||
if prec > 0 {
|
||||
rs := r.string(10)
|
||||
rs := r.decimalString()
|
||||
leadingZeros := prec - len(rs)
|
||||
s += "." + strings.Repeat("0", leadingZeros) + rs
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Gob codec version. Permits backward-compatible changes to the encoding.
|
||||
const ratGobVersion byte = 1
|
||||
|
||||
// GobEncode implements the gob.GobEncoder interface.
|
||||
func (z *Rat) GobEncode() ([]byte, os.Error) {
|
||||
buf := make([]byte, 1+4+(len(z.a.abs)+len(z.b))*_S) // extra bytes for version and sign bit (1), and numerator length (4)
|
||||
i := z.b.bytes(buf)
|
||||
j := z.a.abs.bytes(buf[0:i])
|
||||
n := i - j
|
||||
if int(uint32(n)) != n {
|
||||
// this should never happen
|
||||
return nil, os.NewError("Rat.GobEncode: numerator too large")
|
||||
}
|
||||
binary.BigEndian.PutUint32(buf[j-4:j], uint32(n))
|
||||
j -= 1 + 4
|
||||
b := ratGobVersion << 1 // make space for sign bit
|
||||
if z.a.neg {
|
||||
b |= 1
|
||||
}
|
||||
buf[j] = b
|
||||
return buf[j:], nil
|
||||
}
|
||||
|
||||
// GobDecode implements the gob.GobDecoder interface.
|
||||
func (z *Rat) GobDecode(buf []byte) os.Error {
|
||||
if len(buf) == 0 {
|
||||
return os.NewError("Rat.GobDecode: no data")
|
||||
}
|
||||
b := buf[0]
|
||||
if b>>1 != ratGobVersion {
|
||||
return os.NewError(fmt.Sprintf("Rat.GobDecode: encoding version %d not supported", b>>1))
|
||||
}
|
||||
const j = 1 + 4
|
||||
i := j + binary.BigEndian.Uint32(buf[j-4:j])
|
||||
z.a.neg = b&1 != 0
|
||||
z.a.abs = z.a.abs.setBytes(buf[j:i])
|
||||
z.b = z.b.setBytes(buf[i:])
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,8 +4,12 @@
|
|||
|
||||
package big
|
||||
|
||||
import "testing"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"gob"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var setStringTests = []struct {
|
||||
in, out string
|
||||
|
@ -52,6 +56,27 @@ func TestRatSetString(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRatScan(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
for i, test := range setStringTests {
|
||||
x := new(Rat)
|
||||
buf.Reset()
|
||||
buf.WriteString(test.in)
|
||||
|
||||
_, err := fmt.Fscanf(&buf, "%v", x)
|
||||
if err == nil != test.ok {
|
||||
if test.ok {
|
||||
t.Errorf("#%d error: %s", i, err.String())
|
||||
} else {
|
||||
t.Errorf("#%d expected error", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err == nil && x.RatString() != test.out {
|
||||
t.Errorf("#%d got %s want %s", i, x.RatString(), test.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var floatStringTests = []struct {
|
||||
in string
|
||||
|
@ -59,12 +84,13 @@ var floatStringTests = []struct {
|
|||
out string
|
||||
}{
|
||||
{"0", 0, "0"},
|
||||
{"0", 4, "0"},
|
||||
{"0", 4, "0.0000"},
|
||||
{"1", 0, "1"},
|
||||
{"1", 2, "1"},
|
||||
{"1", 2, "1.00"},
|
||||
{"-1", 0, "-1"},
|
||||
{".25", 2, "0.25"},
|
||||
{".25", 1, "0.3"},
|
||||
{".25", 3, "0.250"},
|
||||
{"-1/3", 3, "-0.333"},
|
||||
{"-2/3", 4, "-0.6667"},
|
||||
{"0.96", 1, "1.0"},
|
||||
|
@ -84,7 +110,6 @@ func TestFloatString(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestRatSign(t *testing.T) {
|
||||
zero := NewRat(0, 1)
|
||||
for _, a := range setStringTests {
|
||||
|
@ -98,7 +123,6 @@ func TestRatSign(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var ratCmpTests = []struct {
|
||||
rat1, rat2 string
|
||||
out int
|
||||
|
@ -126,7 +150,6 @@ func TestRatCmp(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestIsInt(t *testing.T) {
|
||||
one := NewInt(1)
|
||||
for _, a := range setStringTests {
|
||||
|
@ -140,7 +163,6 @@ func TestIsInt(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestRatAbs(t *testing.T) {
|
||||
zero := NewRat(0, 1)
|
||||
for _, a := range setStringTests {
|
||||
|
@ -158,7 +180,6 @@ func TestRatAbs(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type ratBinFun func(z, x, y *Rat) *Rat
|
||||
type ratBinArg struct {
|
||||
x, y, z string
|
||||
|
@ -175,7 +196,6 @@ func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var ratBinTests = []struct {
|
||||
x, y string
|
||||
sum, prod string
|
||||
|
@ -232,7 +252,6 @@ func TestRatBin(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestIssue820(t *testing.T) {
|
||||
x := NewRat(3, 1)
|
||||
y := NewRat(2, 1)
|
||||
|
@ -258,7 +277,6 @@ func TestIssue820(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
var setFrac64Tests = []struct {
|
||||
a, b int64
|
||||
out string
|
||||
|
@ -280,3 +298,35 @@ func TestRatSetFrac64Rat(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRatGobEncoding(t *testing.T) {
|
||||
var medium bytes.Buffer
|
||||
enc := gob.NewEncoder(&medium)
|
||||
dec := gob.NewDecoder(&medium)
|
||||
for i, test := range gobEncodingTests {
|
||||
for j := 0; j < 4; j++ {
|
||||
medium.Reset() // empty buffer for each test case (in case of failures)
|
||||
stest := test
|
||||
if j&1 != 0 {
|
||||
// negative numbers
|
||||
stest = "-" + test
|
||||
}
|
||||
if j%2 != 0 {
|
||||
// fractions
|
||||
stest = stest + "." + test
|
||||
}
|
||||
var tx Rat
|
||||
tx.SetString(stest)
|
||||
if err := enc.Encode(&tx); err != nil {
|
||||
t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
|
||||
}
|
||||
var rx Rat
|
||||
if err := dec.Decode(&rx); err != nil {
|
||||
t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
|
||||
}
|
||||
if rx.Cmp(&tx) != 0 {
|
||||
t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,16 +15,17 @@ import (
|
|||
"utf8"
|
||||
)
|
||||
|
||||
|
||||
const (
|
||||
defaultBufSize = 4096
|
||||
)
|
||||
|
||||
// Errors introduced by this package.
|
||||
type Error struct {
|
||||
os.ErrorString
|
||||
ErrorString string
|
||||
}
|
||||
|
||||
func (err *Error) String() string { return err.ErrorString }
|
||||
|
||||
var (
|
||||
ErrInvalidUnreadByte os.Error = &Error{"bufio: invalid use of UnreadByte"}
|
||||
ErrInvalidUnreadRune os.Error = &Error{"bufio: invalid use of UnreadRune"}
|
||||
|
@ -40,7 +41,6 @@ func (b BufSizeError) String() string {
|
|||
return "bufio: bad buffer size " + strconv.Itoa(int(b))
|
||||
}
|
||||
|
||||
|
||||
// Buffered input.
|
||||
|
||||
// Reader implements buffering for an io.Reader object.
|
||||
|
@ -101,6 +101,12 @@ func (b *Reader) fill() {
|
|||
}
|
||||
}
|
||||
|
||||
func (b *Reader) readErr() os.Error {
|
||||
err := b.err
|
||||
b.err = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Peek returns the next n bytes without advancing the reader. The bytes stop
|
||||
// being valid at the next read call. If Peek returns fewer than n bytes, it
|
||||
// also returns an error explaining why the read is short. The error is
|
||||
|
@ -119,7 +125,7 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) {
|
|||
if m > n {
|
||||
m = n
|
||||
}
|
||||
err := b.err
|
||||
err := b.readErr()
|
||||
if m < n && err == nil {
|
||||
err = ErrBufferFull
|
||||
}
|
||||
|
@ -134,11 +140,11 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) {
|
|||
func (b *Reader) Read(p []byte) (n int, err os.Error) {
|
||||
n = len(p)
|
||||
if n == 0 {
|
||||
return 0, b.err
|
||||
return 0, b.readErr()
|
||||
}
|
||||
if b.w == b.r {
|
||||
if b.err != nil {
|
||||
return 0, b.err
|
||||
return 0, b.readErr()
|
||||
}
|
||||
if len(p) >= len(b.buf) {
|
||||
// Large read, empty buffer.
|
||||
|
@ -148,11 +154,11 @@ func (b *Reader) Read(p []byte) (n int, err os.Error) {
|
|||
b.lastByte = int(p[n-1])
|
||||
b.lastRuneSize = -1
|
||||
}
|
||||
return n, b.err
|
||||
return n, b.readErr()
|
||||
}
|
||||
b.fill()
|
||||
if b.w == b.r {
|
||||
return 0, b.err
|
||||
return 0, b.readErr()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -172,7 +178,7 @@ func (b *Reader) ReadByte() (c byte, err os.Error) {
|
|||
b.lastRuneSize = -1
|
||||
for b.w == b.r {
|
||||
if b.err != nil {
|
||||
return 0, b.err
|
||||
return 0, b.readErr()
|
||||
}
|
||||
b.fill()
|
||||
}
|
||||
|
@ -208,7 +214,7 @@ func (b *Reader) ReadRune() (rune int, size int, err os.Error) {
|
|||
}
|
||||
b.lastRuneSize = -1
|
||||
if b.r == b.w {
|
||||
return 0, 0, b.err
|
||||
return 0, 0, b.readErr()
|
||||
}
|
||||
rune, size = int(b.buf[b.r]), 1
|
||||
if rune >= 0x80 {
|
||||
|
@ -260,7 +266,7 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err os.Error) {
|
|||
if b.err != nil {
|
||||
line := b.buf[b.r:b.w]
|
||||
b.r = b.w
|
||||
return line, b.err
|
||||
return line, b.readErr()
|
||||
}
|
||||
|
||||
n := b.Buffered()
|
||||
|
@ -367,7 +373,6 @@ func (b *Reader) ReadString(delim byte) (line string, err os.Error) {
|
|||
return string(bytes), e
|
||||
}
|
||||
|
||||
|
||||
// buffered output
|
||||
|
||||
// Writer implements buffering for an io.Writer object.
|
||||
|
|
|
@ -53,11 +53,12 @@ func readBytes(buf *Reader) string {
|
|||
if e == os.EOF {
|
||||
break
|
||||
}
|
||||
if e != nil {
|
||||
if e == nil {
|
||||
b[nb] = c
|
||||
nb++
|
||||
} else if e != iotest.ErrTimeout {
|
||||
panic("Data: " + e.String())
|
||||
}
|
||||
b[nb] = c
|
||||
nb++
|
||||
}
|
||||
return string(b[0:nb])
|
||||
}
|
||||
|
@ -75,7 +76,6 @@ func TestReaderSimple(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type readMaker struct {
|
||||
name string
|
||||
fn func(io.Reader) io.Reader
|
||||
|
@ -86,6 +86,7 @@ var readMakers = []readMaker{
|
|||
{"byte", iotest.OneByteReader},
|
||||
{"half", iotest.HalfReader},
|
||||
{"data+err", iotest.DataErrReader},
|
||||
{"timeout", iotest.TimeoutReader},
|
||||
}
|
||||
|
||||
// Call ReadString (which ends up calling everything else)
|
||||
|
@ -97,7 +98,7 @@ func readLines(b *Reader) string {
|
|||
if e == os.EOF {
|
||||
break
|
||||
}
|
||||
if e != nil {
|
||||
if e != nil && e != iotest.ErrTimeout {
|
||||
panic("GetLines: " + e.String())
|
||||
}
|
||||
s += s1
|
||||
|
|
135
libgo/go/builtin/builtin.go
Normal file
135
libgo/go/builtin/builtin.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package builtin provides documentation for Go's built-in functions.
|
||||
The functions documented here are not actually in package builtin
|
||||
but their descriptions here allow godoc to present documentation
|
||||
for the language's special functions.
|
||||
*/
|
||||
package builtin
|
||||
|
||||
// Type is here for the purposes of documentation only. It is a stand-in
|
||||
// for any Go type, but represents the same type for any given function
|
||||
// invocation.
|
||||
type Type int
|
||||
|
||||
// IntegerType is here for the purposes of documentation only. It is a stand-in
|
||||
// for any integer type: int, uint, int8 etc.
|
||||
type IntegerType int
|
||||
|
||||
// FloatType is here for the purposes of documentation only. It is a stand-in
|
||||
// for either float type: float32 or float64.
|
||||
type FloatType int
|
||||
|
||||
// ComplexType is here for the purposes of documentation only. It is a
|
||||
// stand-in for either complex type: complex64 or complex128.
|
||||
type ComplexType int
|
||||
|
||||
// The append built-in function appends elements to the end of a slice. If
|
||||
// it has sufficient capacity, the destination is resliced to accommodate the
|
||||
// new elements. If it does not, a new underlying array will be allocated.
|
||||
// Append returns the updated slice. It is therefore necessary to store the
|
||||
// result of append, often in the variable holding the slice itself:
|
||||
// slice = append(slice, elem1, elem2)
|
||||
// slice = append(slice, anotherSlice...)
|
||||
func append(slice []Type, elems ...Type) []Type
|
||||
|
||||
// The copy built-in function copies elements from a source slice into a
|
||||
// destination slice. (As a special case, it also will copy bytes from a
|
||||
// string to a slice of bytes.) The source and destination may overlap. Copy
|
||||
// returns the number of elements copied, which will be the minimum of
|
||||
// len(src) and len(dst).
|
||||
func copy(dst, src []Type) int
|
||||
|
||||
// The len built-in function returns the length of v, according to its type:
|
||||
// Array: the number of elements in v.
|
||||
// Pointer to array: the number of elements in *v (even if v is nil).
|
||||
// Slice, or map: the number of elements in v; if v is nil, len(v) is zero.
|
||||
// String: the number of bytes in v.
|
||||
// Channel: the number of elements queued (unread) in the channel buffer;
|
||||
// if v is nil, len(v) is zero.
|
||||
func len(v Type) int
|
||||
|
||||
// The cap built-in function returns the capacity of v, according to its type:
|
||||
// Array: the number of elements in v (same as len(v)).
|
||||
// Pointer to array: the number of elements in *v (same as len(v)).
|
||||
// Slice: the maximum length the slice can reach when resliced;
|
||||
// if v is nil, cap(v) is zero.
|
||||
// Channel: the channel buffer capacity, in units of elements;
|
||||
// if v is nil, cap(v) is zero.
|
||||
func cap(v Type) int
|
||||
|
||||
// The make built-in function allocates and initializes an object of type
|
||||
// slice, map, or chan (only). Like new, the first argument is a type, not a
|
||||
// value. Unlike new, make's return type is the same as the type of its
|
||||
// argument, not a pointer to it. The specification of the result depends on
|
||||
// the type:
|
||||
// Slice: The size specifies the length. The capacity of the slice is
|
||||
// equal to its length. A second integer argument may be provided to
|
||||
// specify a different capacity; it must be no smaller than the
|
||||
// length, so make([]int, 0, 10) allocates a slice of length 0 and
|
||||
// capacity 10.
|
||||
// Map: An initial allocation is made according to the size but the
|
||||
// resulting map has length 0. The size may be omitted, in which case
|
||||
// a small starting size is allocated.
|
||||
// Channel: The channel's buffer is initialized with the specified
|
||||
// buffer capacity. If zero, or the size is omitted, the channel is
|
||||
// unbuffered.
|
||||
func make(Type, size IntegerType) Type
|
||||
|
||||
// The new built-in function allocates memory. The first argument is a type,
|
||||
// not a value, and the value returned is a pointer to a newly
|
||||
// allocated zero value of that type.
|
||||
func new(Type) *Type
|
||||
|
||||
// The complex built-in function constructs a complex value from two
|
||||
// floating-point values. The real and imaginary parts must be of the same
|
||||
// size, either float32 or float64 (or assignable to them), and the return
|
||||
// value will be the corresponding complex type (complex64 for float32,
|
||||
// complex128 for float64).
|
||||
func complex(r, i FloatType) ComplexType
|
||||
|
||||
// The real built-in function returns the real part of the complex number c.
|
||||
// The return value will be floating point type corresponding to the type of c.
|
||||
func real(c ComplexType) FloatType
|
||||
|
||||
// The imaginary built-in function returns the imaginary part of the complex
|
||||
// number c. The return value will be floating point type corresponding to
|
||||
// the type of c.
|
||||
func imag(c ComplexType) FloatType
|
||||
|
||||
// The close built-in function closes a channel, which must be either
|
||||
// bidirectional or send-only. It should be executed only by the sender,
|
||||
// never the receiver, and has the effect of shutting down the channel after
|
||||
// the last sent value is received. After the last value has been received
|
||||
// from a closed channel c, any receive from c will succeed without
|
||||
// blocking, returning the zero value for the channel element. The form
|
||||
// x, ok := <-c
|
||||
// will also set ok to false for a closed channel.
|
||||
func close(c chan<- Type)
|
||||
|
||||
// The panic built-in function stops normal execution of the current
|
||||
// goroutine. When a function F calls panic, normal execution of F stops
|
||||
// immediately. Any functions whose execution was deferred by F are run in
|
||||
// the usual way, and then F returns to its caller. To the caller G, the
|
||||
// invocation of F then behaves like a call to panic, terminating G's
|
||||
// execution and running any deferred functions. This continues until all
|
||||
// functions in the executing goroutine have stopped, in reverse order. At
|
||||
// that point, the program is terminated and the error condition is reported,
|
||||
// including the value of the argument to panic. This termination sequence
|
||||
// is called panicking and can be controlled by the built-in function
|
||||
// recover.
|
||||
func panic(v interface{})
|
||||
|
||||
// The recover built-in function allows a program to manage behavior of a
|
||||
// panicking goroutine. Executing a call to recover inside a deferred
|
||||
// function (but not any function called by it) stops the panicking sequence
|
||||
// by restoring normal execution and retrieves the error value passed to the
|
||||
// call of panic. If recover is called outside the deferred function it will
|
||||
// not stop a panicking sequence. In this case, or when the goroutine is not
|
||||
// panicking, or if the argument supplied to panic was nil, recover returns
|
||||
// nil. Thus the return value from recover reports whether the goroutine is
|
||||
// panicking.
|
||||
func recover() interface{}
|
|
@ -280,7 +280,7 @@ func (b *Buffer) ReadRune() (r int, size int, err os.Error) {
|
|||
// from any read operation.)
|
||||
func (b *Buffer) UnreadRune() os.Error {
|
||||
if b.lastRead != opReadRune {
|
||||
return os.ErrorString("bytes.Buffer: UnreadRune: previous operation was not ReadRune")
|
||||
return os.NewError("bytes.Buffer: UnreadRune: previous operation was not ReadRune")
|
||||
}
|
||||
b.lastRead = opInvalid
|
||||
if b.off > 0 {
|
||||
|
@ -295,7 +295,7 @@ func (b *Buffer) UnreadRune() os.Error {
|
|||
// returns an error.
|
||||
func (b *Buffer) UnreadByte() os.Error {
|
||||
if b.lastRead != opReadRune && b.lastRead != opRead {
|
||||
return os.ErrorString("bytes.Buffer: UnreadByte: previous operation was not a read")
|
||||
return os.NewError("bytes.Buffer: UnreadByte: previous operation was not a read")
|
||||
}
|
||||
b.lastRead = opInvalid
|
||||
if b.off > 0 {
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"utf8"
|
||||
)
|
||||
|
||||
|
||||
const N = 10000 // make this bigger for a larger (and slower) test
|
||||
var data string // test data for write tests
|
||||
var bytes []byte // test data; same as data but as a slice.
|
||||
|
@ -47,7 +46,6 @@ func check(t *testing.T, testname string, buf *Buffer, s string) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Fill buf through n writes of string fus.
|
||||
// The initial contents of buf corresponds to the string s;
|
||||
// the result is the final contents of buf returned as a string.
|
||||
|
@ -67,7 +65,6 @@ func fillString(t *testing.T, testname string, buf *Buffer, s string, n int, fus
|
|||
return s
|
||||
}
|
||||
|
||||
|
||||
// Fill buf through n writes of byte slice fub.
|
||||
// The initial contents of buf corresponds to the string s;
|
||||
// the result is the final contents of buf returned as a string.
|
||||
|
@ -87,19 +84,16 @@ func fillBytes(t *testing.T, testname string, buf *Buffer, s string, n int, fub
|
|||
return s
|
||||
}
|
||||
|
||||
|
||||
func TestNewBuffer(t *testing.T) {
|
||||
buf := NewBuffer(bytes)
|
||||
check(t, "NewBuffer", buf, data)
|
||||
}
|
||||
|
||||
|
||||
func TestNewBufferString(t *testing.T) {
|
||||
buf := NewBufferString(data)
|
||||
check(t, "NewBufferString", buf, data)
|
||||
}
|
||||
|
||||
|
||||
// Empty buf through repeated reads into fub.
|
||||
// The initial contents of buf corresponds to the string s.
|
||||
func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
|
||||
|
@ -120,7 +114,6 @@ func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
|
|||
check(t, testname+" (empty 4)", buf, "")
|
||||
}
|
||||
|
||||
|
||||
func TestBasicOperations(t *testing.T) {
|
||||
var buf Buffer
|
||||
|
||||
|
@ -175,7 +168,6 @@ func TestBasicOperations(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestLargeStringWrites(t *testing.T) {
|
||||
var buf Buffer
|
||||
limit := 30
|
||||
|
@ -189,7 +181,6 @@ func TestLargeStringWrites(t *testing.T) {
|
|||
check(t, "TestLargeStringWrites (3)", &buf, "")
|
||||
}
|
||||
|
||||
|
||||
func TestLargeByteWrites(t *testing.T) {
|
||||
var buf Buffer
|
||||
limit := 30
|
||||
|
@ -203,7 +194,6 @@ func TestLargeByteWrites(t *testing.T) {
|
|||
check(t, "TestLargeByteWrites (3)", &buf, "")
|
||||
}
|
||||
|
||||
|
||||
func TestLargeStringReads(t *testing.T) {
|
||||
var buf Buffer
|
||||
for i := 3; i < 30; i += 3 {
|
||||
|
@ -213,7 +203,6 @@ func TestLargeStringReads(t *testing.T) {
|
|||
check(t, "TestLargeStringReads (3)", &buf, "")
|
||||
}
|
||||
|
||||
|
||||
func TestLargeByteReads(t *testing.T) {
|
||||
var buf Buffer
|
||||
for i := 3; i < 30; i += 3 {
|
||||
|
@ -223,7 +212,6 @@ func TestLargeByteReads(t *testing.T) {
|
|||
check(t, "TestLargeByteReads (3)", &buf, "")
|
||||
}
|
||||
|
||||
|
||||
func TestMixedReadsAndWrites(t *testing.T) {
|
||||
var buf Buffer
|
||||
s := ""
|
||||
|
@ -243,7 +231,6 @@ func TestMixedReadsAndWrites(t *testing.T) {
|
|||
empty(t, "TestMixedReadsAndWrites (2)", &buf, s, make([]byte, buf.Len()))
|
||||
}
|
||||
|
||||
|
||||
func TestNil(t *testing.T) {
|
||||
var b *Buffer
|
||||
if b.String() != "<nil>" {
|
||||
|
@ -251,7 +238,6 @@ func TestNil(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestReadFrom(t *testing.T) {
|
||||
var buf Buffer
|
||||
for i := 3; i < 30; i += 3 {
|
||||
|
@ -262,7 +248,6 @@ func TestReadFrom(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestWriteTo(t *testing.T) {
|
||||
var buf Buffer
|
||||
for i := 3; i < 30; i += 3 {
|
||||
|
@ -273,7 +258,6 @@ func TestWriteTo(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestRuneIO(t *testing.T) {
|
||||
const NRune = 1000
|
||||
// Built a test array while we write the data
|
||||
|
@ -323,7 +307,6 @@ func TestRuneIO(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestNext(t *testing.T) {
|
||||
b := []byte{0, 1, 2, 3, 4}
|
||||
tmp := make([]byte, 5)
|
||||
|
|
|
@ -212,24 +212,38 @@ func genSplit(s, sep []byte, sepSave, n int) [][]byte {
|
|||
return a[0 : na+1]
|
||||
}
|
||||
|
||||
// Split slices s into subslices separated by sep and returns a slice of
|
||||
// SplitN slices s into subslices separated by sep and returns a slice of
|
||||
// the subslices between those separators.
|
||||
// If sep is empty, SplitN splits after each UTF-8 sequence.
|
||||
// The count determines the number of subslices to return:
|
||||
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
|
||||
// n == 0: the result is nil (zero subslices)
|
||||
// n < 0: all subslices
|
||||
func SplitN(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) }
|
||||
|
||||
// SplitAfterN slices s into subslices after each instance of sep and
|
||||
// returns a slice of those subslices.
|
||||
// If sep is empty, SplitAfterN splits after each UTF-8 sequence.
|
||||
// The count determines the number of subslices to return:
|
||||
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
|
||||
// n == 0: the result is nil (zero subslices)
|
||||
// n < 0: all subslices
|
||||
func SplitAfterN(s, sep []byte, n int) [][]byte {
|
||||
return genSplit(s, sep, len(sep), n)
|
||||
}
|
||||
|
||||
// Split slices s into all subslices separated by sep and returns a slice of
|
||||
// the subslices between those separators.
|
||||
// If sep is empty, Split splits after each UTF-8 sequence.
|
||||
// The count determines the number of subslices to return:
|
||||
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
|
||||
// n == 0: the result is nil (zero subslices)
|
||||
// n < 0: all subslices
|
||||
func Split(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) }
|
||||
// It is equivalent to SplitN with a count of -1.
|
||||
func Split(s, sep []byte) [][]byte { return genSplit(s, sep, 0, -1) }
|
||||
|
||||
// SplitAfter slices s into subslices after each instance of sep and
|
||||
// SplitAfter slices s into all subslices after each instance of sep and
|
||||
// returns a slice of those subslices.
|
||||
// If sep is empty, Split splits after each UTF-8 sequence.
|
||||
// The count determines the number of subslices to return:
|
||||
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
|
||||
// n == 0: the result is nil (zero subslices)
|
||||
// n < 0: all subslices
|
||||
func SplitAfter(s, sep []byte, n int) [][]byte {
|
||||
return genSplit(s, sep, len(sep), n)
|
||||
// If sep is empty, SplitAfter splits after each UTF-8 sequence.
|
||||
// It is equivalent to SplitAfterN with a count of -1.
|
||||
func SplitAfter(s, sep []byte) [][]byte {
|
||||
return genSplit(s, sep, len(sep), -1)
|
||||
}
|
||||
|
||||
// Fields splits the array s around each instance of one or more consecutive white space
|
||||
|
@ -384,7 +398,6 @@ func ToTitleSpecial(_case unicode.SpecialCase, s []byte) []byte {
|
|||
return Map(func(r int) int { return _case.ToTitle(r) }, s)
|
||||
}
|
||||
|
||||
|
||||
// isSeparator reports whether the rune could mark a word boundary.
|
||||
// TODO: update when package unicode captures more of the properties.
|
||||
func isSeparator(rune int) bool {
|
||||
|
|
|
@ -6,6 +6,7 @@ package bytes_test
|
|||
|
||||
import (
|
||||
. "bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
"unicode"
|
||||
"utf8"
|
||||
|
@ -315,7 +316,7 @@ var explodetests = []ExplodeTest{
|
|||
|
||||
func TestExplode(t *testing.T) {
|
||||
for _, tt := range explodetests {
|
||||
a := Split([]byte(tt.s), nil, tt.n)
|
||||
a := SplitN([]byte(tt.s), nil, tt.n)
|
||||
result := arrayOfString(a)
|
||||
if !eq(result, tt.a) {
|
||||
t.Errorf(`Explode("%s", %d) = %v; want %v`, tt.s, tt.n, result, tt.a)
|
||||
|
@ -328,7 +329,6 @@ func TestExplode(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type SplitTest struct {
|
||||
s string
|
||||
sep string
|
||||
|
@ -354,7 +354,7 @@ var splittests = []SplitTest{
|
|||
|
||||
func TestSplit(t *testing.T) {
|
||||
for _, tt := range splittests {
|
||||
a := Split([]byte(tt.s), []byte(tt.sep), tt.n)
|
||||
a := SplitN([]byte(tt.s), []byte(tt.sep), tt.n)
|
||||
result := arrayOfString(a)
|
||||
if !eq(result, tt.a) {
|
||||
t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a)
|
||||
|
@ -367,6 +367,12 @@ func TestSplit(t *testing.T) {
|
|||
if string(s) != tt.s {
|
||||
t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s)
|
||||
}
|
||||
if tt.n < 0 {
|
||||
b := Split([]byte(tt.s), []byte(tt.sep))
|
||||
if !reflect.DeepEqual(a, b) {
|
||||
t.Errorf("Split disagrees withSplitN(%q, %q, %d) = %v; want %v", tt.s, tt.sep, tt.n, b, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -388,7 +394,7 @@ var splitaftertests = []SplitTest{
|
|||
|
||||
func TestSplitAfter(t *testing.T) {
|
||||
for _, tt := range splitaftertests {
|
||||
a := SplitAfter([]byte(tt.s), []byte(tt.sep), tt.n)
|
||||
a := SplitAfterN([]byte(tt.s), []byte(tt.sep), tt.n)
|
||||
result := arrayOfString(a)
|
||||
if !eq(result, tt.a) {
|
||||
t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a)
|
||||
|
@ -398,6 +404,12 @@ func TestSplitAfter(t *testing.T) {
|
|||
if string(s) != tt.s {
|
||||
t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s)
|
||||
}
|
||||
if tt.n < 0 {
|
||||
b := SplitAfter([]byte(tt.s), []byte(tt.sep))
|
||||
if !reflect.DeepEqual(a, b) {
|
||||
t.Errorf("SplitAfter disagrees withSplitAfterN(%q, %q, %d) = %v; want %v", tt.s, tt.sep, tt.n, b, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -649,7 +661,6 @@ func TestRunes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
type TrimTest struct {
|
||||
f func([]byte, string) []byte
|
||||
in, cutset, out string
|
||||
|
|
|
@ -284,7 +284,7 @@ func (bz2 *reader) readBlock() (err os.Error) {
|
|||
repeat := 0
|
||||
repeat_power := 0
|
||||
|
||||
// The `C' array (used by the inverse BWT) needs to be zero initialised.
|
||||
// The `C' array (used by the inverse BWT) needs to be zero initialized.
|
||||
for i := range bz2.c {
|
||||
bz2.c[i] = 0
|
||||
}
|
||||
|
@ -330,7 +330,7 @@ func (bz2 *reader) readBlock() (err os.Error) {
|
|||
|
||||
if int(v) == numSymbols-1 {
|
||||
// This is the EOF symbol. Because it's always at the
|
||||
// end of the move-to-front list, and nevers gets moved
|
||||
// end of the move-to-front list, and never gets moved
|
||||
// to the front, it has this unique value.
|
||||
break
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ func newHuffmanTree(lengths []uint8) (huffmanTree, os.Error) {
|
|||
// each symbol (consider reflecting a tree down the middle, for
|
||||
// example). Since the code length assignments determine the
|
||||
// efficiency of the tree, each of these trees is equally good. In
|
||||
// order to minimise the amount of information needed to build a tree
|
||||
// order to minimize the amount of information needed to build a tree
|
||||
// bzip2 uses a canonical tree so that it can be reconstructed given
|
||||
// only the code length assignments.
|
||||
|
||||
|
|
|
@ -11,16 +11,18 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
NoCompression = 0
|
||||
BestSpeed = 1
|
||||
fastCompression = 3
|
||||
BestCompression = 9
|
||||
DefaultCompression = -1
|
||||
logMaxOffsetSize = 15 // Standard DEFLATE
|
||||
wideLogMaxOffsetSize = 22 // Wide DEFLATE
|
||||
minMatchLength = 3 // The smallest match that the compressor looks for
|
||||
maxMatchLength = 258 // The longest match for the compressor
|
||||
minOffsetSize = 1 // The shortest offset that makes any sence
|
||||
NoCompression = 0
|
||||
BestSpeed = 1
|
||||
fastCompression = 3
|
||||
BestCompression = 9
|
||||
DefaultCompression = -1
|
||||
logWindowSize = 15
|
||||
windowSize = 1 << logWindowSize
|
||||
windowMask = windowSize - 1
|
||||
logMaxOffsetSize = 15 // Standard DEFLATE
|
||||
minMatchLength = 3 // The smallest match that the compressor looks for
|
||||
maxMatchLength = 258 // The longest match for the compressor
|
||||
minOffsetSize = 1 // The shortest offset that makes any sence
|
||||
|
||||
// The maximum number of tokens we put into a single flat block, just too
|
||||
// stop things from getting too large.
|
||||
|
@ -32,22 +34,6 @@ const (
|
|||
hashShift = (hashBits + minMatchLength - 1) / minMatchLength
|
||||
)
|
||||
|
||||
type syncPipeReader struct {
|
||||
*io.PipeReader
|
||||
closeChan chan bool
|
||||
}
|
||||
|
||||
func (sr *syncPipeReader) CloseWithError(err os.Error) os.Error {
|
||||
retErr := sr.PipeReader.CloseWithError(err)
|
||||
sr.closeChan <- true // finish writer close
|
||||
return retErr
|
||||
}
|
||||
|
||||
type syncPipeWriter struct {
|
||||
*io.PipeWriter
|
||||
closeChan chan bool
|
||||
}
|
||||
|
||||
type compressionLevel struct {
|
||||
good, lazy, nice, chain, fastSkipHashing int
|
||||
}
|
||||
|
@ -68,105 +54,73 @@ var levels = []compressionLevel{
|
|||
{32, 258, 258, 4096, math.MaxInt32},
|
||||
}
|
||||
|
||||
func (sw *syncPipeWriter) Close() os.Error {
|
||||
err := sw.PipeWriter.Close()
|
||||
<-sw.closeChan // wait for reader close
|
||||
return err
|
||||
}
|
||||
|
||||
func syncPipe() (*syncPipeReader, *syncPipeWriter) {
|
||||
r, w := io.Pipe()
|
||||
sr := &syncPipeReader{r, make(chan bool, 1)}
|
||||
sw := &syncPipeWriter{w, sr.closeChan}
|
||||
return sr, sw
|
||||
}
|
||||
|
||||
type compressor struct {
|
||||
level int
|
||||
logWindowSize uint
|
||||
w *huffmanBitWriter
|
||||
r io.Reader
|
||||
// (1 << logWindowSize) - 1.
|
||||
windowMask int
|
||||
compressionLevel
|
||||
|
||||
eof bool // has eof been reached on input?
|
||||
sync bool // writer wants to flush
|
||||
syncChan chan os.Error
|
||||
w *huffmanBitWriter
|
||||
|
||||
// compression algorithm
|
||||
fill func(*compressor, []byte) int // copy data to window
|
||||
step func(*compressor) // process window
|
||||
sync bool // requesting flush
|
||||
|
||||
// Input hash chains
|
||||
// hashHead[hashValue] contains the largest inputIndex with the specified hash value
|
||||
hashHead []int
|
||||
|
||||
// If hashHead[hashValue] is within the current window, then
|
||||
// hashPrev[hashHead[hashValue] & windowMask] contains the previous index
|
||||
// with the same hash value.
|
||||
hashPrev []int
|
||||
chainHead int
|
||||
hashHead []int
|
||||
hashPrev []int
|
||||
|
||||
// If we find a match of length >= niceMatch, then we don't bother searching
|
||||
// any further.
|
||||
niceMatch int
|
||||
// input window: unprocessed data is window[index:windowEnd]
|
||||
index int
|
||||
window []byte
|
||||
windowEnd int
|
||||
blockStart int // window index where current tokens start
|
||||
byteAvailable bool // if true, still need to process window[index-1].
|
||||
|
||||
// If we find a match of length >= goodMatch, we only do a half-hearted
|
||||
// effort at doing lazy matching starting at the next character
|
||||
goodMatch int
|
||||
// queued output tokens: tokens[:ti]
|
||||
tokens []token
|
||||
ti int
|
||||
|
||||
// The maximum number of chains we look at when finding a match
|
||||
maxChainLength int
|
||||
|
||||
// The sliding window we use for matching
|
||||
window []byte
|
||||
|
||||
// The index just past the last valid character
|
||||
windowEnd int
|
||||
|
||||
// index in "window" at which current block starts
|
||||
blockStart int
|
||||
// deflate state
|
||||
length int
|
||||
offset int
|
||||
hash int
|
||||
maxInsertIndex int
|
||||
err os.Error
|
||||
}
|
||||
|
||||
func (d *compressor) flush() os.Error {
|
||||
d.w.flush()
|
||||
return d.w.err
|
||||
}
|
||||
|
||||
func (d *compressor) fillWindow(index int) (int, os.Error) {
|
||||
if d.sync {
|
||||
return index, nil
|
||||
}
|
||||
wSize := d.windowMask + 1
|
||||
if index >= wSize+wSize-(minMatchLength+maxMatchLength) {
|
||||
// shift the window by wSize
|
||||
copy(d.window, d.window[wSize:2*wSize])
|
||||
index -= wSize
|
||||
d.windowEnd -= wSize
|
||||
if d.blockStart >= wSize {
|
||||
d.blockStart -= wSize
|
||||
func (d *compressor) fillDeflate(b []byte) int {
|
||||
if d.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
|
||||
// shift the window by windowSize
|
||||
copy(d.window, d.window[windowSize:2*windowSize])
|
||||
d.index -= windowSize
|
||||
d.windowEnd -= windowSize
|
||||
if d.blockStart >= windowSize {
|
||||
d.blockStart -= windowSize
|
||||
} else {
|
||||
d.blockStart = math.MaxInt32
|
||||
}
|
||||
for i, h := range d.hashHead {
|
||||
v := h - wSize
|
||||
v := h - windowSize
|
||||
if v < -1 {
|
||||
v = -1
|
||||
}
|
||||
d.hashHead[i] = v
|
||||
}
|
||||
for i, h := range d.hashPrev {
|
||||
v := -h - wSize
|
||||
v := -h - windowSize
|
||||
if v < -1 {
|
||||
v = -1
|
||||
}
|
||||
d.hashPrev[i] = v
|
||||
}
|
||||
}
|
||||
count, err := d.r.Read(d.window[d.windowEnd:])
|
||||
d.windowEnd += count
|
||||
if count == 0 && err == nil {
|
||||
d.sync = true
|
||||
}
|
||||
if err == os.EOF {
|
||||
d.eof = true
|
||||
err = nil
|
||||
}
|
||||
return index, err
|
||||
n := copy(d.window[d.windowEnd:], b)
|
||||
d.windowEnd += n
|
||||
return n
|
||||
}
|
||||
|
||||
func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error {
|
||||
|
@ -194,21 +148,21 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
|
|||
|
||||
// We quit when we get a match that's at least nice long
|
||||
nice := len(win) - pos
|
||||
if d.niceMatch < nice {
|
||||
nice = d.niceMatch
|
||||
if d.nice < nice {
|
||||
nice = d.nice
|
||||
}
|
||||
|
||||
// If we've got a match that's good enough, only look in 1/4 the chain.
|
||||
tries := d.maxChainLength
|
||||
tries := d.chain
|
||||
length = prevLength
|
||||
if length >= d.goodMatch {
|
||||
if length >= d.good {
|
||||
tries >>= 2
|
||||
}
|
||||
|
||||
w0 := win[pos]
|
||||
w1 := win[pos+1]
|
||||
wEnd := win[pos+length]
|
||||
minIndex := pos - (d.windowMask + 1)
|
||||
minIndex := pos - windowSize
|
||||
|
||||
for i := prevHead; tries > 0; tries-- {
|
||||
if w0 == win[i] && w1 == win[i+1] && wEnd == win[i+length] {
|
||||
|
@ -233,7 +187,7 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
|
|||
// hashPrev[i & windowMask] has already been overwritten, so stop now.
|
||||
break
|
||||
}
|
||||
if i = d.hashPrev[i&d.windowMask]; i < minIndex || i < 0 {
|
||||
if i = d.hashPrev[i&windowMask]; i < minIndex || i < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -248,234 +202,224 @@ func (d *compressor) writeStoredBlock(buf []byte) os.Error {
|
|||
return d.w.err
|
||||
}
|
||||
|
||||
func (d *compressor) storedDeflate() os.Error {
|
||||
buf := make([]byte, maxStoreBlockSize)
|
||||
for {
|
||||
n, err := d.r.Read(buf)
|
||||
if n == 0 && err == nil {
|
||||
d.sync = true
|
||||
}
|
||||
if n > 0 || d.sync {
|
||||
if err := d.writeStoredBlock(buf[0:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
if d.sync {
|
||||
d.syncChan <- nil
|
||||
d.sync = false
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == os.EOF {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func (d *compressor) initDeflate() {
|
||||
d.hashHead = make([]int, hashSize)
|
||||
d.hashPrev = make([]int, windowSize)
|
||||
d.window = make([]byte, 2*windowSize)
|
||||
fillInts(d.hashHead, -1)
|
||||
d.tokens = make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
|
||||
d.length = minMatchLength - 1
|
||||
d.offset = 0
|
||||
d.byteAvailable = false
|
||||
d.index = 0
|
||||
d.ti = 0
|
||||
d.hash = 0
|
||||
d.chainHead = -1
|
||||
}
|
||||
|
||||
func (d *compressor) doDeflate() (err os.Error) {
|
||||
// init
|
||||
d.windowMask = 1<<d.logWindowSize - 1
|
||||
d.hashHead = make([]int, hashSize)
|
||||
d.hashPrev = make([]int, 1<<d.logWindowSize)
|
||||
d.window = make([]byte, 2<<d.logWindowSize)
|
||||
fillInts(d.hashHead, -1)
|
||||
tokens := make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
|
||||
l := levels[d.level]
|
||||
d.goodMatch = l.good
|
||||
d.niceMatch = l.nice
|
||||
d.maxChainLength = l.chain
|
||||
lazyMatch := l.lazy
|
||||
length := minMatchLength - 1
|
||||
offset := 0
|
||||
byteAvailable := false
|
||||
isFastDeflate := l.fastSkipHashing != 0
|
||||
index := 0
|
||||
// run
|
||||
if index, err = d.fillWindow(index); err != nil {
|
||||
func (d *compressor) deflate() {
|
||||
if d.windowEnd-d.index < minMatchLength+maxMatchLength && !d.sync {
|
||||
return
|
||||
}
|
||||
maxOffset := d.windowMask + 1 // (1 << logWindowSize);
|
||||
// only need to change when you refill the window
|
||||
windowEnd := d.windowEnd
|
||||
maxInsertIndex := windowEnd - (minMatchLength - 1)
|
||||
ti := 0
|
||||
|
||||
hash := int(0)
|
||||
if index < maxInsertIndex {
|
||||
hash = int(d.window[index])<<hashShift + int(d.window[index+1])
|
||||
d.maxInsertIndex = d.windowEnd - (minMatchLength - 1)
|
||||
if d.index < d.maxInsertIndex {
|
||||
d.hash = int(d.window[d.index])<<hashShift + int(d.window[d.index+1])
|
||||
}
|
||||
chainHead := -1
|
||||
|
||||
Loop:
|
||||
for {
|
||||
if index > windowEnd {
|
||||
if d.index > d.windowEnd {
|
||||
panic("index > windowEnd")
|
||||
}
|
||||
lookahead := windowEnd - index
|
||||
lookahead := d.windowEnd - d.index
|
||||
if lookahead < minMatchLength+maxMatchLength {
|
||||
if index, err = d.fillWindow(index); err != nil {
|
||||
return
|
||||
if !d.sync {
|
||||
break Loop
|
||||
}
|
||||
windowEnd = d.windowEnd
|
||||
if index > windowEnd {
|
||||
if d.index > d.windowEnd {
|
||||
panic("index > windowEnd")
|
||||
}
|
||||
maxInsertIndex = windowEnd - (minMatchLength - 1)
|
||||
lookahead = windowEnd - index
|
||||
if lookahead == 0 {
|
||||
// Flush current output block if any.
|
||||
if byteAvailable {
|
||||
if d.byteAvailable {
|
||||
// There is still one pending token that needs to be flushed
|
||||
tokens[ti] = literalToken(uint32(d.window[index-1]) & 0xFF)
|
||||
ti++
|
||||
byteAvailable = false
|
||||
d.tokens[d.ti] = literalToken(uint32(d.window[d.index-1]))
|
||||
d.ti++
|
||||
d.byteAvailable = false
|
||||
}
|
||||
if ti > 0 {
|
||||
if err = d.writeBlock(tokens[0:ti], index, false); err != nil {
|
||||
if d.ti > 0 {
|
||||
if d.err = d.writeBlock(d.tokens[0:d.ti], d.index, false); d.err != nil {
|
||||
return
|
||||
}
|
||||
ti = 0
|
||||
}
|
||||
if d.sync {
|
||||
d.w.writeStoredHeader(0, false)
|
||||
d.w.flush()
|
||||
d.syncChan <- d.w.err
|
||||
d.sync = false
|
||||
}
|
||||
|
||||
// If this was only a sync (not at EOF) keep going.
|
||||
if !d.eof {
|
||||
continue
|
||||
d.ti = 0
|
||||
}
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
if index < maxInsertIndex {
|
||||
if d.index < d.maxInsertIndex {
|
||||
// Update the hash
|
||||
hash = (hash<<hashShift + int(d.window[index+2])) & hashMask
|
||||
chainHead = d.hashHead[hash]
|
||||
d.hashPrev[index&d.windowMask] = chainHead
|
||||
d.hashHead[hash] = index
|
||||
d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
|
||||
d.chainHead = d.hashHead[d.hash]
|
||||
d.hashPrev[d.index&windowMask] = d.chainHead
|
||||
d.hashHead[d.hash] = d.index
|
||||
}
|
||||
prevLength := length
|
||||
prevOffset := offset
|
||||
length = minMatchLength - 1
|
||||
offset = 0
|
||||
minIndex := index - maxOffset
|
||||
prevLength := d.length
|
||||
prevOffset := d.offset
|
||||
d.length = minMatchLength - 1
|
||||
d.offset = 0
|
||||
minIndex := d.index - windowSize
|
||||
if minIndex < 0 {
|
||||
minIndex = 0
|
||||
}
|
||||
|
||||
if chainHead >= minIndex &&
|
||||
(isFastDeflate && lookahead > minMatchLength-1 ||
|
||||
!isFastDeflate && lookahead > prevLength && prevLength < lazyMatch) {
|
||||
if newLength, newOffset, ok := d.findMatch(index, chainHead, minMatchLength-1, lookahead); ok {
|
||||
length = newLength
|
||||
offset = newOffset
|
||||
if d.chainHead >= minIndex &&
|
||||
(d.fastSkipHashing != 0 && lookahead > minMatchLength-1 ||
|
||||
d.fastSkipHashing == 0 && lookahead > prevLength && prevLength < d.lazy) {
|
||||
if newLength, newOffset, ok := d.findMatch(d.index, d.chainHead, minMatchLength-1, lookahead); ok {
|
||||
d.length = newLength
|
||||
d.offset = newOffset
|
||||
}
|
||||
}
|
||||
if isFastDeflate && length >= minMatchLength ||
|
||||
!isFastDeflate && prevLength >= minMatchLength && length <= prevLength {
|
||||
if d.fastSkipHashing != 0 && d.length >= minMatchLength ||
|
||||
d.fastSkipHashing == 0 && prevLength >= minMatchLength && d.length <= prevLength {
|
||||
// There was a match at the previous step, and the current match is
|
||||
// not better. Output the previous match.
|
||||
if isFastDeflate {
|
||||
tokens[ti] = matchToken(uint32(length-minMatchLength), uint32(offset-minOffsetSize))
|
||||
if d.fastSkipHashing != 0 {
|
||||
d.tokens[d.ti] = matchToken(uint32(d.length-minMatchLength), uint32(d.offset-minOffsetSize))
|
||||
} else {
|
||||
tokens[ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize))
|
||||
d.tokens[d.ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize))
|
||||
}
|
||||
ti++
|
||||
d.ti++
|
||||
// Insert in the hash table all strings up to the end of the match.
|
||||
// index and index-1 are already inserted. If there is not enough
|
||||
// lookahead, the last two strings are not inserted into the hash
|
||||
// table.
|
||||
if length <= l.fastSkipHashing {
|
||||
if d.length <= d.fastSkipHashing {
|
||||
var newIndex int
|
||||
if isFastDeflate {
|
||||
newIndex = index + length
|
||||
if d.fastSkipHashing != 0 {
|
||||
newIndex = d.index + d.length
|
||||
} else {
|
||||
newIndex = prevLength - 1
|
||||
}
|
||||
for index++; index < newIndex; index++ {
|
||||
if index < maxInsertIndex {
|
||||
hash = (hash<<hashShift + int(d.window[index+2])) & hashMask
|
||||
for d.index++; d.index < newIndex; d.index++ {
|
||||
if d.index < d.maxInsertIndex {
|
||||
d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
|
||||
// Get previous value with the same hash.
|
||||
// Our chain should point to the previous value.
|
||||
d.hashPrev[index&d.windowMask] = d.hashHead[hash]
|
||||
d.hashPrev[d.index&windowMask] = d.hashHead[d.hash]
|
||||
// Set the head of the hash chain to us.
|
||||
d.hashHead[hash] = index
|
||||
d.hashHead[d.hash] = d.index
|
||||
}
|
||||
}
|
||||
if !isFastDeflate {
|
||||
byteAvailable = false
|
||||
length = minMatchLength - 1
|
||||
if d.fastSkipHashing == 0 {
|
||||
d.byteAvailable = false
|
||||
d.length = minMatchLength - 1
|
||||
}
|
||||
} else {
|
||||
// For matches this long, we don't bother inserting each individual
|
||||
// item into the table.
|
||||
index += length
|
||||
hash = (int(d.window[index])<<hashShift + int(d.window[index+1]))
|
||||
d.index += d.length
|
||||
d.hash = (int(d.window[d.index])<<hashShift + int(d.window[d.index+1]))
|
||||
}
|
||||
if ti == maxFlateBlockTokens {
|
||||
if d.ti == maxFlateBlockTokens {
|
||||
// The block includes the current character
|
||||
if err = d.writeBlock(tokens, index, false); err != nil {
|
||||
if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil {
|
||||
return
|
||||
}
|
||||
ti = 0
|
||||
d.ti = 0
|
||||
}
|
||||
} else {
|
||||
if isFastDeflate || byteAvailable {
|
||||
i := index - 1
|
||||
if isFastDeflate {
|
||||
i = index
|
||||
if d.fastSkipHashing != 0 || d.byteAvailable {
|
||||
i := d.index - 1
|
||||
if d.fastSkipHashing != 0 {
|
||||
i = d.index
|
||||
}
|
||||
tokens[ti] = literalToken(uint32(d.window[i]) & 0xFF)
|
||||
ti++
|
||||
if ti == maxFlateBlockTokens {
|
||||
if err = d.writeBlock(tokens, i+1, false); err != nil {
|
||||
d.tokens[d.ti] = literalToken(uint32(d.window[i]))
|
||||
d.ti++
|
||||
if d.ti == maxFlateBlockTokens {
|
||||
if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil {
|
||||
return
|
||||
}
|
||||
ti = 0
|
||||
d.ti = 0
|
||||
}
|
||||
}
|
||||
index++
|
||||
if !isFastDeflate {
|
||||
byteAvailable = true
|
||||
d.index++
|
||||
if d.fastSkipHashing == 0 {
|
||||
d.byteAvailable = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize uint) (err os.Error) {
|
||||
d.r = r
|
||||
func (d *compressor) fillStore(b []byte) int {
|
||||
n := copy(d.window[d.windowEnd:], b)
|
||||
d.windowEnd += n
|
||||
return n
|
||||
}
|
||||
|
||||
func (d *compressor) store() {
|
||||
if d.windowEnd > 0 {
|
||||
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
|
||||
}
|
||||
d.windowEnd = 0
|
||||
}
|
||||
|
||||
func (d *compressor) write(b []byte) (n int, err os.Error) {
|
||||
n = len(b)
|
||||
b = b[d.fill(d, b):]
|
||||
for len(b) > 0 {
|
||||
d.step(d)
|
||||
b = b[d.fill(d, b):]
|
||||
}
|
||||
return n, d.err
|
||||
}
|
||||
|
||||
func (d *compressor) syncFlush() os.Error {
|
||||
d.sync = true
|
||||
d.step(d)
|
||||
if d.err == nil {
|
||||
d.w.writeStoredHeader(0, false)
|
||||
d.w.flush()
|
||||
d.err = d.w.err
|
||||
}
|
||||
d.sync = false
|
||||
return d.err
|
||||
}
|
||||
|
||||
func (d *compressor) init(w io.Writer, level int) (err os.Error) {
|
||||
d.w = newHuffmanBitWriter(w)
|
||||
d.level = level
|
||||
d.logWindowSize = logWindowSize
|
||||
|
||||
switch {
|
||||
case level == NoCompression:
|
||||
err = d.storedDeflate()
|
||||
d.window = make([]byte, maxStoreBlockSize)
|
||||
d.fill = (*compressor).fillStore
|
||||
d.step = (*compressor).store
|
||||
case level == DefaultCompression:
|
||||
d.level = 6
|
||||
level = 6
|
||||
fallthrough
|
||||
case 1 <= level && level <= 9:
|
||||
err = d.doDeflate()
|
||||
d.compressionLevel = levels[level]
|
||||
d.initDeflate()
|
||||
d.fill = (*compressor).fillDeflate
|
||||
d.step = (*compressor).deflate
|
||||
default:
|
||||
return WrongValueError{"level", 0, 9, int32(level)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if d.sync {
|
||||
d.syncChan <- err
|
||||
d.sync = false
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
func (d *compressor) close() os.Error {
|
||||
d.sync = true
|
||||
d.step(d)
|
||||
if d.err != nil {
|
||||
return d.err
|
||||
}
|
||||
if d.w.writeStoredHeader(0, true); d.w.err != nil {
|
||||
return d.w.err
|
||||
}
|
||||
return d.flush()
|
||||
d.w.flush()
|
||||
return d.w.err
|
||||
}
|
||||
|
||||
// NewWriter returns a new Writer compressing
|
||||
|
@ -486,14 +430,9 @@ func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize
|
|||
// compression; it only adds the necessary DEFLATE framing.
|
||||
func NewWriter(w io.Writer, level int) *Writer {
|
||||
const logWindowSize = logMaxOffsetSize
|
||||
var d compressor
|
||||
d.syncChan = make(chan os.Error, 1)
|
||||
pr, pw := syncPipe()
|
||||
go func() {
|
||||
err := d.compress(pr, w, level, logWindowSize)
|
||||
pr.CloseWithError(err)
|
||||
}()
|
||||
return &Writer{pw, &d}
|
||||
var dw Writer
|
||||
dw.d.init(w, level)
|
||||
return &dw
|
||||
}
|
||||
|
||||
// NewWriterDict is like NewWriter but initializes the new
|
||||
|
@ -526,18 +465,13 @@ func (w *dictWriter) Write(b []byte) (n int, err os.Error) {
|
|||
// A Writer takes data written to it and writes the compressed
|
||||
// form of that data to an underlying writer (see NewWriter).
|
||||
type Writer struct {
|
||||
w *syncPipeWriter
|
||||
d *compressor
|
||||
d compressor
|
||||
}
|
||||
|
||||
// Write writes data to w, which will eventually write the
|
||||
// compressed form of data to its underlying writer.
|
||||
func (w *Writer) Write(data []byte) (n int, err os.Error) {
|
||||
if len(data) == 0 {
|
||||
// no point, and nil interferes with sync
|
||||
return
|
||||
}
|
||||
return w.w.Write(data)
|
||||
return w.d.write(data)
|
||||
}
|
||||
|
||||
// Flush flushes any pending compressed data to the underlying writer.
|
||||
|
@ -550,18 +484,10 @@ func (w *Writer) Write(data []byte) (n int, err os.Error) {
|
|||
func (w *Writer) Flush() os.Error {
|
||||
// For more about flushing:
|
||||
// http://www.bolet.org/~pornin/deflate-flush.html
|
||||
if w.d.sync {
|
||||
panic("compress/flate: double Flush")
|
||||
}
|
||||
_, err := w.w.Write(nil)
|
||||
err1 := <-w.d.syncChan
|
||||
if err == nil {
|
||||
err = err1
|
||||
}
|
||||
return err
|
||||
return w.d.syncFlush()
|
||||
}
|
||||
|
||||
// Close flushes and closes the writer.
|
||||
func (w *Writer) Close() os.Error {
|
||||
return w.w.Close()
|
||||
return w.d.close()
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ var deflateInflateTests = []*deflateInflateTest{
|
|||
&deflateInflateTest{[]byte{0x11, 0x12}},
|
||||
&deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
|
||||
&deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
|
||||
&deflateInflateTest{getLargeDataChunk()},
|
||||
&deflateInflateTest{largeDataChunk()},
|
||||
}
|
||||
|
||||
var reverseBitsTests = []*reverseBitsTest{
|
||||
|
@ -71,23 +71,22 @@ var reverseBitsTests = []*reverseBitsTest{
|
|||
&reverseBitsTest{29, 5, 23},
|
||||
}
|
||||
|
||||
func getLargeDataChunk() []byte {
|
||||
func largeDataChunk() []byte {
|
||||
result := make([]byte, 100000)
|
||||
for i := range result {
|
||||
result[i] = byte(int64(i) * int64(i) & 0xFF)
|
||||
result[i] = byte(i * i & 0xFF)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TestDeflate(t *testing.T) {
|
||||
for _, h := range deflateTests {
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
w := NewWriter(buffer, h.level)
|
||||
var buf bytes.Buffer
|
||||
w := NewWriter(&buf, h.level)
|
||||
w.Write(h.in)
|
||||
w.Close()
|
||||
if bytes.Compare(buffer.Bytes(), h.out) != 0 {
|
||||
t.Errorf("buffer is wrong; level = %v, buffer.Bytes() = %v, expected output = %v",
|
||||
h.level, buffer.Bytes(), h.out)
|
||||
if !bytes.Equal(buf.Bytes(), h.out) {
|
||||
t.Errorf("Deflate(%d, %x) = %x, want %x", h.level, h.in, buf.Bytes(), h.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -226,7 +225,6 @@ func testSync(t *testing.T, level int, input []byte, name string) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func testToFromWithLevel(t *testing.T, level int, input []byte, name string) os.Error {
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
w := NewWriter(buffer, level)
|
||||
|
|
|
@ -15,9 +15,6 @@ const (
|
|||
// The largest offset code.
|
||||
offsetCodeCount = 30
|
||||
|
||||
// The largest offset code in the extensions.
|
||||
extendedOffsetCodeCount = 42
|
||||
|
||||
// The special code used to mark the end of a block.
|
||||
endBlockMarker = 256
|
||||
|
||||
|
@ -100,11 +97,11 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
|
|||
return &huffmanBitWriter{
|
||||
w: w,
|
||||
literalFreq: make([]int32, maxLit),
|
||||
offsetFreq: make([]int32, extendedOffsetCodeCount),
|
||||
codegen: make([]uint8, maxLit+extendedOffsetCodeCount+1),
|
||||
offsetFreq: make([]int32, offsetCodeCount),
|
||||
codegen: make([]uint8, maxLit+offsetCodeCount+1),
|
||||
codegenFreq: make([]int32, codegenCodeCount),
|
||||
literalEncoding: newHuffmanEncoder(maxLit),
|
||||
offsetEncoding: newHuffmanEncoder(extendedOffsetCodeCount),
|
||||
offsetEncoding: newHuffmanEncoder(offsetCodeCount),
|
||||
codegenEncoding: newHuffmanEncoder(codegenCodeCount),
|
||||
}
|
||||
}
|
||||
|
@ -185,7 +182,7 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
|
|||
_, w.err = w.w.Write(bytes)
|
||||
}
|
||||
|
||||
// RFC 1951 3.2.7 specifies a special run-length encoding for specifiying
|
||||
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying
|
||||
// the literal and offset lengths arrays (which are concatenated into a single
|
||||
// array). This method generates that run-length encoding.
|
||||
//
|
||||
|
@ -279,7 +276,7 @@ func (w *huffmanBitWriter) writeCode(code *huffmanEncoder, literal uint32) {
|
|||
//
|
||||
// numLiterals The number of literals specified in codegen
|
||||
// numOffsets The number of offsets specified in codegen
|
||||
// numCodegens Tne number of codegens used in codegen
|
||||
// numCodegens The number of codegens used in codegen
|
||||
func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) {
|
||||
if w.err != nil {
|
||||
return
|
||||
|
@ -290,13 +287,7 @@ func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, n
|
|||
}
|
||||
w.writeBits(firstBits, 3)
|
||||
w.writeBits(int32(numLiterals-257), 5)
|
||||
if numOffsets > offsetCodeCount {
|
||||
// Extended version of decompressor
|
||||
w.writeBits(int32(offsetCodeCount+((numOffsets-(1+offsetCodeCount))>>3)), 5)
|
||||
w.writeBits(int32((numOffsets-(1+offsetCodeCount))&0x7), 3)
|
||||
} else {
|
||||
w.writeBits(int32(numOffsets-1), 5)
|
||||
}
|
||||
w.writeBits(int32(numOffsets-1), 5)
|
||||
w.writeBits(int32(numCodegens-4), 4)
|
||||
|
||||
for i := 0; i < numCodegens; i++ {
|
||||
|
@ -368,24 +359,17 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
|
|||
tokens = tokens[0 : n+1]
|
||||
tokens[n] = endBlockMarker
|
||||
|
||||
totalLength := -1 // Subtract 1 for endBlock.
|
||||
for _, t := range tokens {
|
||||
switch t.typ() {
|
||||
case literalType:
|
||||
w.literalFreq[t.literal()]++
|
||||
totalLength++
|
||||
break
|
||||
case matchType:
|
||||
length := t.length()
|
||||
offset := t.offset()
|
||||
totalLength += int(length + 3)
|
||||
w.literalFreq[lengthCodesStart+lengthCode(length)]++
|
||||
w.offsetFreq[offsetCode(offset)]++
|
||||
break
|
||||
}
|
||||
}
|
||||
w.literalEncoding.generate(w.literalFreq, 15)
|
||||
w.offsetEncoding.generate(w.offsetFreq, 15)
|
||||
|
||||
// get the number of literals
|
||||
numLiterals := len(w.literalFreq)
|
||||
|
@ -394,15 +378,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
|
|||
}
|
||||
// get the number of offsets
|
||||
numOffsets := len(w.offsetFreq)
|
||||
for numOffsets > 1 && w.offsetFreq[numOffsets-1] == 0 {
|
||||
for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 {
|
||||
numOffsets--
|
||||
}
|
||||
if numOffsets == 0 {
|
||||
// We haven't found a single match. If we want to go with the dynamic encoding,
|
||||
// we should count at least one offset to be sure that the offset huffman tree could be encoded.
|
||||
w.offsetFreq[0] = 1
|
||||
numOffsets = 1
|
||||
}
|
||||
|
||||
w.literalEncoding.generate(w.literalFreq, 15)
|
||||
w.offsetEncoding.generate(w.offsetFreq, 15)
|
||||
|
||||
storedBytes := 0
|
||||
if input != nil {
|
||||
storedBytes = len(input)
|
||||
}
|
||||
var extraBits int64
|
||||
var storedSize int64
|
||||
var storedSize int64 = math.MaxInt64
|
||||
if storedBytes <= maxStoreBlockSize && input != nil {
|
||||
storedSize = int64((storedBytes + 5) * 8)
|
||||
// We only bother calculating the costs of the extra bits required by
|
||||
|
@ -417,34 +411,29 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
|
|||
// First four offset codes have extra size = 0.
|
||||
extraBits += int64(w.offsetFreq[offsetCode]) * int64(offsetExtraBits[offsetCode])
|
||||
}
|
||||
} else {
|
||||
storedSize = math.MaxInt32
|
||||
}
|
||||
|
||||
// Figure out which generates smaller code, fixed Huffman, dynamic
|
||||
// Huffman, or just storing the data.
|
||||
var fixedSize int64 = math.MaxInt64
|
||||
if numOffsets <= offsetCodeCount {
|
||||
fixedSize = int64(3) +
|
||||
fixedLiteralEncoding.bitLength(w.literalFreq) +
|
||||
fixedOffsetEncoding.bitLength(w.offsetFreq) +
|
||||
extraBits
|
||||
}
|
||||
// Figure out smallest code.
|
||||
// Fixed Huffman baseline.
|
||||
var size = int64(3) +
|
||||
fixedLiteralEncoding.bitLength(w.literalFreq) +
|
||||
fixedOffsetEncoding.bitLength(w.offsetFreq) +
|
||||
extraBits
|
||||
var literalEncoding = fixedLiteralEncoding
|
||||
var offsetEncoding = fixedOffsetEncoding
|
||||
|
||||
// Dynamic Huffman?
|
||||
var numCodegens int
|
||||
|
||||
// Generate codegen and codegenFrequencies, which indicates how to encode
|
||||
// the literalEncoding and the offsetEncoding.
|
||||
w.generateCodegen(numLiterals, numOffsets)
|
||||
w.codegenEncoding.generate(w.codegenFreq, 7)
|
||||
numCodegens := len(w.codegenFreq)
|
||||
numCodegens = len(w.codegenFreq)
|
||||
for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
|
||||
numCodegens--
|
||||
}
|
||||
extensionSummand := 0
|
||||
if numOffsets > offsetCodeCount {
|
||||
extensionSummand = 3
|
||||
}
|
||||
dynamicHeader := int64(3+5+5+4+(3*numCodegens)) +
|
||||
// Following line is an extension.
|
||||
int64(extensionSummand) +
|
||||
w.codegenEncoding.bitLength(w.codegenFreq) +
|
||||
int64(extraBits) +
|
||||
int64(w.codegenFreq[16]*2) +
|
||||
|
@ -454,26 +443,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
|
|||
w.literalEncoding.bitLength(w.literalFreq) +
|
||||
w.offsetEncoding.bitLength(w.offsetFreq)
|
||||
|
||||
if storedSize < fixedSize && storedSize < dynamicSize {
|
||||
w.writeStoredHeader(storedBytes, eof)
|
||||
w.writeBytes(input[0:storedBytes])
|
||||
return
|
||||
}
|
||||
var literalEncoding *huffmanEncoder
|
||||
var offsetEncoding *huffmanEncoder
|
||||
|
||||
if fixedSize <= dynamicSize {
|
||||
w.writeFixedHeader(eof)
|
||||
literalEncoding = fixedLiteralEncoding
|
||||
offsetEncoding = fixedOffsetEncoding
|
||||
} else {
|
||||
// Write the header.
|
||||
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
|
||||
if dynamicSize < size {
|
||||
size = dynamicSize
|
||||
literalEncoding = w.literalEncoding
|
||||
offsetEncoding = w.offsetEncoding
|
||||
}
|
||||
|
||||
// Write the tokens.
|
||||
// Stored bytes?
|
||||
if storedSize < size {
|
||||
w.writeStoredHeader(storedBytes, eof)
|
||||
w.writeBytes(input[0:storedBytes])
|
||||
return
|
||||
}
|
||||
|
||||
// Huffman.
|
||||
if literalEncoding == fixedLiteralEncoding {
|
||||
w.writeFixedHeader(eof)
|
||||
} else {
|
||||
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
|
||||
}
|
||||
for _, t := range tokens {
|
||||
switch t.typ() {
|
||||
case literalType:
|
||||
|
|
|
@ -363,7 +363,12 @@ func (s literalNodeSorter) Less(i, j int) bool {
|
|||
func (s literalNodeSorter) Swap(i, j int) { s.a[i], s.a[j] = s.a[j], s.a[i] }
|
||||
|
||||
func sortByFreq(a []literalNode) {
|
||||
s := &literalNodeSorter{a, func(i, j int) bool { return a[i].freq < a[j].freq }}
|
||||
s := &literalNodeSorter{a, func(i, j int) bool {
|
||||
if a[i].freq == a[j].freq {
|
||||
return a[i].literal < a[j].literal
|
||||
}
|
||||
return a[i].freq < a[j].freq
|
||||
}}
|
||||
sort.Sort(s)
|
||||
}
|
||||
|
||||
|
|
|
@ -77,8 +77,6 @@ type huffmanDecoder struct {
|
|||
|
||||
// Initialize Huffman decoding tables from array of code lengths.
|
||||
func (h *huffmanDecoder) init(bits []int) bool {
|
||||
// TODO(rsc): Return false sometimes.
|
||||
|
||||
// Count number of codes of each length,
|
||||
// compute min and max length.
|
||||
var count [maxCodeLen + 1]int
|
||||
|
@ -197,9 +195,8 @@ type Reader interface {
|
|||
|
||||
// Decompress state.
|
||||
type decompressor struct {
|
||||
// Input/output sources.
|
||||
// Input source.
|
||||
r Reader
|
||||
w io.Writer
|
||||
roffset int64
|
||||
woffset int64
|
||||
|
||||
|
@ -222,38 +219,79 @@ type decompressor struct {
|
|||
|
||||
// Temporary buffer (avoids repeated allocation).
|
||||
buf [4]byte
|
||||
|
||||
// Next step in the decompression,
|
||||
// and decompression state.
|
||||
step func(*decompressor)
|
||||
final bool
|
||||
err os.Error
|
||||
toRead []byte
|
||||
hl, hd *huffmanDecoder
|
||||
copyLen int
|
||||
copyDist int
|
||||
}
|
||||
|
||||
func (f *decompressor) inflate() (err os.Error) {
|
||||
final := false
|
||||
for err == nil && !final {
|
||||
for f.nb < 1+2 {
|
||||
if err = f.moreBits(); err != nil {
|
||||
return
|
||||
}
|
||||
func (f *decompressor) nextBlock() {
|
||||
if f.final {
|
||||
if f.hw != f.hp {
|
||||
f.flush((*decompressor).nextBlock)
|
||||
return
|
||||
}
|
||||
final = f.b&1 == 1
|
||||
f.b >>= 1
|
||||
typ := f.b & 3
|
||||
f.b >>= 2
|
||||
f.nb -= 1 + 2
|
||||
switch typ {
|
||||
case 0:
|
||||
err = f.dataBlock()
|
||||
case 1:
|
||||
// compressed, fixed Huffman tables
|
||||
err = f.decodeBlock(&fixedHuffmanDecoder, nil)
|
||||
case 2:
|
||||
// compressed, dynamic Huffman tables
|
||||
if err = f.readHuffman(); err == nil {
|
||||
err = f.decodeBlock(&f.h1, &f.h2)
|
||||
}
|
||||
default:
|
||||
// 3 is reserved.
|
||||
err = CorruptInputError(f.roffset)
|
||||
f.err = os.EOF
|
||||
return
|
||||
}
|
||||
for f.nb < 1+2 {
|
||||
if f.err = f.moreBits(); f.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
f.final = f.b&1 == 1
|
||||
f.b >>= 1
|
||||
typ := f.b & 3
|
||||
f.b >>= 2
|
||||
f.nb -= 1 + 2
|
||||
switch typ {
|
||||
case 0:
|
||||
f.dataBlock()
|
||||
case 1:
|
||||
// compressed, fixed Huffman tables
|
||||
f.hl = &fixedHuffmanDecoder
|
||||
f.hd = nil
|
||||
f.huffmanBlock()
|
||||
case 2:
|
||||
// compressed, dynamic Huffman tables
|
||||
if f.err = f.readHuffman(); f.err != nil {
|
||||
break
|
||||
}
|
||||
f.hl = &f.h1
|
||||
f.hd = &f.h2
|
||||
f.huffmanBlock()
|
||||
default:
|
||||
// 3 is reserved.
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *decompressor) Read(b []byte) (int, os.Error) {
|
||||
for {
|
||||
if len(f.toRead) > 0 {
|
||||
n := copy(b, f.toRead)
|
||||
f.toRead = f.toRead[n:]
|
||||
return n, nil
|
||||
}
|
||||
if f.err != nil {
|
||||
return 0, f.err
|
||||
}
|
||||
f.step(f)
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (f *decompressor) Close() os.Error {
|
||||
if f.err == os.EOF {
|
||||
return nil
|
||||
}
|
||||
return f.err
|
||||
}
|
||||
|
||||
// RFC 1951 section 3.2.7.
|
||||
|
@ -358,11 +396,12 @@ func (f *decompressor) readHuffman() os.Error {
|
|||
// hl and hd are the Huffman states for the lit/length values
|
||||
// and the distance values, respectively. If hd == nil, using the
|
||||
// fixed distance encoding associated with fixed Huffman blocks.
|
||||
func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
|
||||
func (f *decompressor) huffmanBlock() {
|
||||
for {
|
||||
v, err := f.huffSym(hl)
|
||||
v, err := f.huffSym(f.hl)
|
||||
if err != nil {
|
||||
return err
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
var n uint // number of bits extra
|
||||
var length int
|
||||
|
@ -371,13 +410,15 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
|
|||
f.hist[f.hp] = byte(v)
|
||||
f.hp++
|
||||
if f.hp == len(f.hist) {
|
||||
if err = f.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
// After the flush, continue this loop.
|
||||
f.flush((*decompressor).huffmanBlock)
|
||||
return
|
||||
}
|
||||
continue
|
||||
case v == 256:
|
||||
return nil
|
||||
// Done with huffman block; read next block.
|
||||
f.step = (*decompressor).nextBlock
|
||||
return
|
||||
// otherwise, reference to older data
|
||||
case v < 265:
|
||||
length = v - (257 - 3)
|
||||
|
@ -404,7 +445,8 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
|
|||
if n > 0 {
|
||||
for f.nb < n {
|
||||
if err = f.moreBits(); err != nil {
|
||||
return err
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
length += int(f.b & uint32(1<<n-1))
|
||||
|
@ -413,18 +455,20 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
|
|||
}
|
||||
|
||||
var dist int
|
||||
if hd == nil {
|
||||
if f.hd == nil {
|
||||
for f.nb < 5 {
|
||||
if err = f.moreBits(); err != nil {
|
||||
return err
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
dist = int(reverseByte[(f.b&0x1F)<<3])
|
||||
f.b >>= 5
|
||||
f.nb -= 5
|
||||
} else {
|
||||
if dist, err = f.huffSym(hd); err != nil {
|
||||
return err
|
||||
if dist, err = f.huffSym(f.hd); err != nil {
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -432,14 +476,16 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
|
|||
case dist < 4:
|
||||
dist++
|
||||
case dist >= 30:
|
||||
return CorruptInputError(f.roffset)
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
return
|
||||
default:
|
||||
nb := uint(dist-2) >> 1
|
||||
// have 1 bit in bottom of dist, need nb more.
|
||||
extra := (dist & 1) << nb
|
||||
for f.nb < nb {
|
||||
if err = f.moreBits(); err != nil {
|
||||
return err
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
extra |= int(f.b & uint32(1<<nb-1))
|
||||
|
@ -450,12 +496,14 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
|
|||
|
||||
// Copy history[-dist:-dist+length] into output.
|
||||
if dist > len(f.hist) {
|
||||
return InternalError("bad history distance")
|
||||
f.err = InternalError("bad history distance")
|
||||
return
|
||||
}
|
||||
|
||||
// No check on length; encoding can be prescient.
|
||||
if !f.hfull && dist > f.hp {
|
||||
return CorruptInputError(f.roffset)
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
return
|
||||
}
|
||||
|
||||
p := f.hp - dist
|
||||
|
@ -467,9 +515,11 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
|
|||
f.hp++
|
||||
p++
|
||||
if f.hp == len(f.hist) {
|
||||
if err = f.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
// After flush continue copying out of history.
|
||||
f.copyLen = length - (i + 1)
|
||||
f.copyDist = dist
|
||||
f.flush((*decompressor).copyHuff)
|
||||
return
|
||||
}
|
||||
if p == len(f.hist) {
|
||||
p = 0
|
||||
|
@ -479,8 +529,33 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
|
|||
panic("unreached")
|
||||
}
|
||||
|
||||
func (f *decompressor) copyHuff() {
|
||||
length := f.copyLen
|
||||
dist := f.copyDist
|
||||
p := f.hp - dist
|
||||
if p < 0 {
|
||||
p += len(f.hist)
|
||||
}
|
||||
for i := 0; i < length; i++ {
|
||||
f.hist[f.hp] = f.hist[p]
|
||||
f.hp++
|
||||
p++
|
||||
if f.hp == len(f.hist) {
|
||||
f.copyLen = length - (i + 1)
|
||||
f.flush((*decompressor).copyHuff)
|
||||
return
|
||||
}
|
||||
if p == len(f.hist) {
|
||||
p = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Continue processing Huffman block.
|
||||
f.huffmanBlock()
|
||||
}
|
||||
|
||||
// Copy a single uncompressed data block from input to output.
|
||||
func (f *decompressor) dataBlock() os.Error {
|
||||
func (f *decompressor) dataBlock() {
|
||||
// Uncompressed.
|
||||
// Discard current half-byte.
|
||||
f.nb = 0
|
||||
|
@ -490,21 +565,30 @@ func (f *decompressor) dataBlock() os.Error {
|
|||
nr, err := io.ReadFull(f.r, f.buf[0:4])
|
||||
f.roffset += int64(nr)
|
||||
if err != nil {
|
||||
return &ReadError{f.roffset, err}
|
||||
f.err = &ReadError{f.roffset, err}
|
||||
return
|
||||
}
|
||||
n := int(f.buf[0]) | int(f.buf[1])<<8
|
||||
nn := int(f.buf[2]) | int(f.buf[3])<<8
|
||||
if uint16(nn) != uint16(^n) {
|
||||
return CorruptInputError(f.roffset)
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
return
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
// 0-length block means sync
|
||||
return f.flush()
|
||||
f.flush((*decompressor).nextBlock)
|
||||
return
|
||||
}
|
||||
|
||||
// Read len bytes into history,
|
||||
// writing as history fills.
|
||||
f.copyLen = n
|
||||
f.copyData()
|
||||
}
|
||||
|
||||
func (f *decompressor) copyData() {
|
||||
// Read f.dataLen bytes into history,
|
||||
// pausing for reads as history fills.
|
||||
n := f.copyLen
|
||||
for n > 0 {
|
||||
m := len(f.hist) - f.hp
|
||||
if m > n {
|
||||
|
@ -513,17 +597,18 @@ func (f *decompressor) dataBlock() os.Error {
|
|||
m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m])
|
||||
f.roffset += int64(m)
|
||||
if err != nil {
|
||||
return &ReadError{f.roffset, err}
|
||||
f.err = &ReadError{f.roffset, err}
|
||||
return
|
||||
}
|
||||
n -= m
|
||||
f.hp += m
|
||||
if f.hp == len(f.hist) {
|
||||
if err = f.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
f.copyLen = n
|
||||
f.flush((*decompressor).copyData)
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil
|
||||
f.step = (*decompressor).nextBlock
|
||||
}
|
||||
|
||||
func (f *decompressor) setDict(dict []byte) {
|
||||
|
@ -579,17 +664,8 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, os.Error) {
|
|||
}
|
||||
|
||||
// Flush any buffered output to the underlying writer.
|
||||
func (f *decompressor) flush() os.Error {
|
||||
if f.hw == f.hp {
|
||||
return nil
|
||||
}
|
||||
n, err := f.w.Write(f.hist[f.hw:f.hp])
|
||||
if n != f.hp-f.hw && err == nil {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
if err != nil {
|
||||
return &WriteError{f.woffset, err}
|
||||
}
|
||||
func (f *decompressor) flush(step func(*decompressor)) {
|
||||
f.toRead = f.hist[f.hw:f.hp]
|
||||
f.woffset += int64(f.hp - f.hw)
|
||||
f.hw = f.hp
|
||||
if f.hp == len(f.hist) {
|
||||
|
@ -597,7 +673,7 @@ func (f *decompressor) flush() os.Error {
|
|||
f.hw = 0
|
||||
f.hfull = true
|
||||
}
|
||||
return nil
|
||||
f.step = step
|
||||
}
|
||||
|
||||
func makeReader(r io.Reader) Reader {
|
||||
|
@ -607,30 +683,15 @@ func makeReader(r io.Reader) Reader {
|
|||
return bufio.NewReader(r)
|
||||
}
|
||||
|
||||
// decompress reads DEFLATE-compressed data from r and writes
|
||||
// the uncompressed data to w.
|
||||
func (f *decompressor) decompress(r io.Reader, w io.Writer) os.Error {
|
||||
f.r = makeReader(r)
|
||||
f.w = w
|
||||
f.woffset = 0
|
||||
if err := f.inflate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewReader returns a new ReadCloser that can be used
|
||||
// to read the uncompressed version of r. It is the caller's
|
||||
// responsibility to call Close on the ReadCloser when
|
||||
// finished reading.
|
||||
func NewReader(r io.Reader) io.ReadCloser {
|
||||
var f decompressor
|
||||
pr, pw := io.Pipe()
|
||||
go func() { pw.CloseWithError(f.decompress(r, pw)) }()
|
||||
return pr
|
||||
f.r = makeReader(r)
|
||||
f.step = (*decompressor).nextBlock
|
||||
return &f
|
||||
}
|
||||
|
||||
// NewReaderDict is like NewReader but initializes the reader
|
||||
|
@ -641,7 +702,7 @@ func NewReader(r io.Reader) io.ReadCloser {
|
|||
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
|
||||
var f decompressor
|
||||
f.setDict(dict)
|
||||
pr, pw := io.Pipe()
|
||||
go func() { pw.CloseWithError(f.decompress(r, pw)) }()
|
||||
return pr
|
||||
f.r = makeReader(r)
|
||||
f.step = (*decompressor).nextBlock
|
||||
return &f
|
||||
}
|
||||
|
|
|
@ -36,8 +36,8 @@ func makeReader(r io.Reader) flate.Reader {
|
|||
return bufio.NewReader(r)
|
||||
}
|
||||
|
||||
var HeaderError os.Error = os.ErrorString("invalid gzip header")
|
||||
var ChecksumError os.Error = os.ErrorString("gzip checksum error")
|
||||
var HeaderError = os.NewError("invalid gzip header")
|
||||
var ChecksumError = os.NewError("gzip checksum error")
|
||||
|
||||
// The gzip file stores a header giving metadata about the compressed file.
|
||||
// That header is exposed as the fields of the Compressor and Decompressor structs.
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
// pipe creates two ends of a pipe that gzip and gunzip, and runs dfunc at the
|
||||
// writer end and ifunc at the reader end.
|
||||
// writer end and cfunc at the reader end.
|
||||
func pipe(t *testing.T, dfunc func(*Compressor), cfunc func(*Decompressor)) {
|
||||
piper, pipew := io.Pipe()
|
||||
defer piper.Close()
|
||||
|
|
|
@ -32,13 +32,49 @@ const (
|
|||
MSB
|
||||
)
|
||||
|
||||
const (
|
||||
maxWidth = 12
|
||||
decoderInvalidCode = 0xffff
|
||||
flushBuffer = 1 << maxWidth
|
||||
)
|
||||
|
||||
// decoder is the state from which the readXxx method converts a byte
|
||||
// stream into a code stream.
|
||||
type decoder struct {
|
||||
r io.ByteReader
|
||||
bits uint32
|
||||
nBits uint
|
||||
width uint
|
||||
r io.ByteReader
|
||||
bits uint32
|
||||
nBits uint
|
||||
width uint
|
||||
read func(*decoder) (uint16, os.Error) // readLSB or readMSB
|
||||
litWidth int // width in bits of literal codes
|
||||
err os.Error
|
||||
|
||||
// The first 1<<litWidth codes are literal codes.
|
||||
// The next two codes mean clear and EOF.
|
||||
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
|
||||
// with the upper bound incrementing on each code seen.
|
||||
// overflow is the code at which hi overflows the code width.
|
||||
// last is the most recently seen code, or decoderInvalidCode.
|
||||
clear, eof, hi, overflow, last uint16
|
||||
|
||||
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
|
||||
// suffix[c] is the last of these bytes.
|
||||
// prefix[c] is the code for all but the last byte.
|
||||
// This code can either be a literal code or another code in [lo, c).
|
||||
// The c == hi case is a special case.
|
||||
suffix [1 << maxWidth]uint8
|
||||
prefix [1 << maxWidth]uint16
|
||||
|
||||
// output is the temporary output buffer.
|
||||
// Literal codes are accumulated from the start of the buffer.
|
||||
// Non-literal codes decode to a sequence of suffixes that are first
|
||||
// written right-to-left from the end of the buffer before being copied
|
||||
// to the start of the buffer.
|
||||
// It is flushed when it contains >= 1<<maxWidth bytes,
|
||||
// so that there is always room to decode an entire code.
|
||||
output [2 * 1 << maxWidth]byte
|
||||
o int // write index into output
|
||||
toRead []byte // bytes to return from Read
|
||||
}
|
||||
|
||||
// readLSB returns the next code for "Least Significant Bits first" data.
|
||||
|
@ -73,119 +109,113 @@ func (d *decoder) readMSB() (uint16, os.Error) {
|
|||
return code, nil
|
||||
}
|
||||
|
||||
// decode decompresses bytes from r and writes them to pw.
|
||||
// read specifies how to decode bytes into codes.
|
||||
// litWidth is the width in bits of literal codes.
|
||||
func decode(r io.Reader, read func(*decoder) (uint16, os.Error), litWidth int, pw *io.PipeWriter) {
|
||||
br, ok := r.(io.ByteReader)
|
||||
if !ok {
|
||||
br = bufio.NewReader(r)
|
||||
func (d *decoder) Read(b []byte) (int, os.Error) {
|
||||
for {
|
||||
if len(d.toRead) > 0 {
|
||||
n := copy(b, d.toRead)
|
||||
d.toRead = d.toRead[n:]
|
||||
return n, nil
|
||||
}
|
||||
if d.err != nil {
|
||||
return 0, d.err
|
||||
}
|
||||
d.decode()
|
||||
}
|
||||
pw.CloseWithError(decode1(pw, br, read, uint(litWidth)))
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os.Error), litWidth uint) os.Error {
|
||||
const (
|
||||
maxWidth = 12
|
||||
invalidCode = 0xffff
|
||||
)
|
||||
d := decoder{r, 0, 0, 1 + litWidth}
|
||||
w := bufio.NewWriter(pw)
|
||||
// The first 1<<litWidth codes are literal codes.
|
||||
// The next two codes mean clear and EOF.
|
||||
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
|
||||
// with the upper bound incrementing on each code seen.
|
||||
clear := uint16(1) << litWidth
|
||||
eof, hi := clear+1, clear+1
|
||||
// overflow is the code at which hi overflows the code width.
|
||||
overflow := uint16(1) << d.width
|
||||
var (
|
||||
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
|
||||
// suffix[c] is the last of these bytes.
|
||||
// prefix[c] is the code for all but the last byte.
|
||||
// This code can either be a literal code or another code in [lo, c).
|
||||
// The c == hi case is a special case.
|
||||
suffix [1 << maxWidth]uint8
|
||||
prefix [1 << maxWidth]uint16
|
||||
// buf is a scratch buffer for reconstituting the bytes that a code expands to.
|
||||
// Code suffixes are written right-to-left from the end of the buffer.
|
||||
buf [1 << maxWidth]byte
|
||||
)
|
||||
|
||||
// decode decompresses bytes from r and leaves them in d.toRead.
|
||||
// read specifies how to decode bytes into codes.
|
||||
// litWidth is the width in bits of literal codes.
|
||||
func (d *decoder) decode() {
|
||||
// Loop over the code stream, converting codes into decompressed bytes.
|
||||
last := uint16(invalidCode)
|
||||
for {
|
||||
code, err := read(&d)
|
||||
code, err := d.read(d)
|
||||
if err != nil {
|
||||
if err == os.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
d.err = err
|
||||
return
|
||||
}
|
||||
switch {
|
||||
case code < clear:
|
||||
case code < d.clear:
|
||||
// We have a literal code.
|
||||
if err := w.WriteByte(uint8(code)); err != nil {
|
||||
return err
|
||||
}
|
||||
if last != invalidCode {
|
||||
d.output[d.o] = uint8(code)
|
||||
d.o++
|
||||
if d.last != decoderInvalidCode {
|
||||
// Save what the hi code expands to.
|
||||
suffix[hi] = uint8(code)
|
||||
prefix[hi] = last
|
||||
d.suffix[d.hi] = uint8(code)
|
||||
d.prefix[d.hi] = d.last
|
||||
}
|
||||
case code == clear:
|
||||
d.width = 1 + litWidth
|
||||
hi = eof
|
||||
overflow = 1 << d.width
|
||||
last = invalidCode
|
||||
case code == d.clear:
|
||||
d.width = 1 + uint(d.litWidth)
|
||||
d.hi = d.eof
|
||||
d.overflow = 1 << d.width
|
||||
d.last = decoderInvalidCode
|
||||
continue
|
||||
case code == eof:
|
||||
return w.Flush()
|
||||
case code <= hi:
|
||||
c, i := code, len(buf)-1
|
||||
if code == hi {
|
||||
case code == d.eof:
|
||||
d.flush()
|
||||
d.err = os.EOF
|
||||
return
|
||||
case code <= d.hi:
|
||||
c, i := code, len(d.output)-1
|
||||
if code == d.hi {
|
||||
// code == hi is a special case which expands to the last expansion
|
||||
// followed by the head of the last expansion. To find the head, we walk
|
||||
// the prefix chain until we find a literal code.
|
||||
c = last
|
||||
for c >= clear {
|
||||
c = prefix[c]
|
||||
c = d.last
|
||||
for c >= d.clear {
|
||||
c = d.prefix[c]
|
||||
}
|
||||
buf[i] = uint8(c)
|
||||
d.output[i] = uint8(c)
|
||||
i--
|
||||
c = last
|
||||
c = d.last
|
||||
}
|
||||
// Copy the suffix chain into buf and then write that to w.
|
||||
for c >= clear {
|
||||
buf[i] = suffix[c]
|
||||
// Copy the suffix chain into output and then write that to w.
|
||||
for c >= d.clear {
|
||||
d.output[i] = d.suffix[c]
|
||||
i--
|
||||
c = prefix[c]
|
||||
c = d.prefix[c]
|
||||
}
|
||||
buf[i] = uint8(c)
|
||||
if _, err := w.Write(buf[i:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if last != invalidCode {
|
||||
d.output[i] = uint8(c)
|
||||
d.o += copy(d.output[d.o:], d.output[i:])
|
||||
if d.last != decoderInvalidCode {
|
||||
// Save what the hi code expands to.
|
||||
suffix[hi] = uint8(c)
|
||||
prefix[hi] = last
|
||||
d.suffix[d.hi] = uint8(c)
|
||||
d.prefix[d.hi] = d.last
|
||||
}
|
||||
default:
|
||||
return os.NewError("lzw: invalid code")
|
||||
d.err = os.NewError("lzw: invalid code")
|
||||
return
|
||||
}
|
||||
last, hi = code, hi+1
|
||||
if hi >= overflow {
|
||||
d.last, d.hi = code, d.hi+1
|
||||
if d.hi >= d.overflow {
|
||||
if d.width == maxWidth {
|
||||
last = invalidCode
|
||||
continue
|
||||
d.last = decoderInvalidCode
|
||||
} else {
|
||||
d.width++
|
||||
d.overflow <<= 1
|
||||
}
|
||||
d.width++
|
||||
overflow <<= 1
|
||||
}
|
||||
if d.o >= flushBuffer {
|
||||
d.flush()
|
||||
return
|
||||
}
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (d *decoder) flush() {
|
||||
d.toRead = d.output[:d.o]
|
||||
d.o = 0
|
||||
}
|
||||
|
||||
func (d *decoder) Close() os.Error {
|
||||
d.err = os.EINVAL // in case any Reads come along
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewReader creates a new io.ReadCloser that satisfies reads by decompressing
|
||||
// the data read from r.
|
||||
// It is the caller's responsibility to call Close on the ReadCloser when
|
||||
|
@ -193,21 +223,31 @@ func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os
|
|||
// The number of bits to use for literal codes, litWidth, must be in the
|
||||
// range [2,8] and is typically 8.
|
||||
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
|
||||
pr, pw := io.Pipe()
|
||||
var read func(*decoder) (uint16, os.Error)
|
||||
d := new(decoder)
|
||||
switch order {
|
||||
case LSB:
|
||||
read = (*decoder).readLSB
|
||||
d.read = (*decoder).readLSB
|
||||
case MSB:
|
||||
read = (*decoder).readMSB
|
||||
d.read = (*decoder).readMSB
|
||||
default:
|
||||
pw.CloseWithError(os.NewError("lzw: unknown order"))
|
||||
return pr
|
||||
d.err = os.NewError("lzw: unknown order")
|
||||
return d
|
||||
}
|
||||
if litWidth < 2 || 8 < litWidth {
|
||||
pw.CloseWithError(fmt.Errorf("lzw: litWidth %d out of range", litWidth))
|
||||
return pr
|
||||
d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
|
||||
return d
|
||||
}
|
||||
go decode(r, read, litWidth, pw)
|
||||
return pr
|
||||
if br, ok := r.(io.ByteReader); ok {
|
||||
d.r = br
|
||||
} else {
|
||||
d.r = bufio.NewReader(r)
|
||||
}
|
||||
d.litWidth = litWidth
|
||||
d.width = 1 + uint(litWidth)
|
||||
d.clear = uint16(1) << uint(litWidth)
|
||||
d.eof, d.hi = d.clear+1, d.clear+1
|
||||
d.overflow = uint16(1) << d.width
|
||||
d.last = decoderInvalidCode
|
||||
|
||||
return d
|
||||
}
|
||||
|
|
|
@ -84,7 +84,7 @@ var lzwTests = []lzwTest{
|
|||
func TestReader(t *testing.T) {
|
||||
b := bytes.NewBuffer(nil)
|
||||
for _, tt := range lzwTests {
|
||||
d := strings.Split(tt.desc, ";", -1)
|
||||
d := strings.Split(tt.desc, ";")
|
||||
var order Order
|
||||
switch d[1] {
|
||||
case "LSB":
|
||||
|
|
|
@ -77,13 +77,13 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
|
|||
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
|
||||
return
|
||||
}
|
||||
if len(b0) != len(b1) {
|
||||
t.Errorf("%s (order=%d litWidth=%d): length mismatch %d versus %d", fn, order, litWidth, len(b0), len(b1))
|
||||
if len(b1) != len(b0) {
|
||||
t.Errorf("%s (order=%d litWidth=%d): length mismatch %d != %d", fn, order, litWidth, len(b1), len(b0))
|
||||
return
|
||||
}
|
||||
for i := 0; i < len(b0); i++ {
|
||||
if b0[i] != b1[i] {
|
||||
t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, order, litWidth, i, b0[i], b1[i])
|
||||
if b1[i] != b0[i] {
|
||||
t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x != 0x%02x\n", fn, order, litWidth, i, b1[i], b0[i])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,9 +34,9 @@ import (
|
|||
|
||||
const zlibDeflate = 8
|
||||
|
||||
var ChecksumError os.Error = os.ErrorString("zlib checksum error")
|
||||
var HeaderError os.Error = os.ErrorString("invalid zlib header")
|
||||
var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary")
|
||||
var ChecksumError = os.NewError("zlib checksum error")
|
||||
var HeaderError = os.NewError("invalid zlib header")
|
||||
var DictionaryError = os.NewError("invalid zlib dictionary")
|
||||
|
||||
type reader struct {
|
||||
r flate.Reader
|
||||
|
|
|
@ -89,7 +89,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) {
|
|||
}
|
||||
}
|
||||
z.w = w
|
||||
z.compressor = flate.NewWriter(w, level)
|
||||
z.compressor = flate.NewWriterDict(w, level, dict)
|
||||
z.digest = adler32.New()
|
||||
return z, nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
package zlib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -16,15 +18,13 @@ var filenames = []string{
|
|||
"../testdata/pi.txt",
|
||||
}
|
||||
|
||||
var data = []string{
|
||||
"test a reasonable sized string that can be compressed",
|
||||
}
|
||||
|
||||
// Tests that compressing and then decompressing the given file at the given compression level and dictionary
|
||||
// yields equivalent bytes to the original file.
|
||||
func testFileLevelDict(t *testing.T, fn string, level int, d string) {
|
||||
// Read dictionary, if given.
|
||||
var dict []byte
|
||||
if d != "" {
|
||||
dict = []byte(d)
|
||||
}
|
||||
|
||||
// Read the file, as golden output.
|
||||
golden, err := os.Open(fn)
|
||||
if err != nil {
|
||||
|
@ -32,17 +32,25 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
|
|||
return
|
||||
}
|
||||
defer golden.Close()
|
||||
|
||||
// Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end.
|
||||
raw, err := os.Open(fn)
|
||||
if err != nil {
|
||||
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
|
||||
b0, err0 := ioutil.ReadAll(golden)
|
||||
if err0 != nil {
|
||||
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
|
||||
return
|
||||
}
|
||||
testLevelDict(t, fn, b0, level, d)
|
||||
}
|
||||
|
||||
func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
|
||||
// Make dictionary, if given.
|
||||
var dict []byte
|
||||
if d != "" {
|
||||
dict = []byte(d)
|
||||
}
|
||||
|
||||
// Push data through a pipe that compresses at the write end, and decompresses at the read end.
|
||||
piper, pipew := io.Pipe()
|
||||
defer piper.Close()
|
||||
go func() {
|
||||
defer raw.Close()
|
||||
defer pipew.Close()
|
||||
zlibw, err := NewWriterDict(pipew, level, dict)
|
||||
if err != nil {
|
||||
|
@ -50,25 +58,14 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
|
|||
return
|
||||
}
|
||||
defer zlibw.Close()
|
||||
var b [1024]byte
|
||||
for {
|
||||
n, err0 := raw.Read(b[0:])
|
||||
if err0 != nil && err0 != os.EOF {
|
||||
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
|
||||
return
|
||||
}
|
||||
_, err1 := zlibw.Write(b[0:n])
|
||||
if err1 == os.EPIPE {
|
||||
// Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
|
||||
return
|
||||
}
|
||||
if err1 != nil {
|
||||
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
|
||||
return
|
||||
}
|
||||
if err0 == os.EOF {
|
||||
break
|
||||
}
|
||||
_, err = zlibw.Write(b0)
|
||||
if err == os.EPIPE {
|
||||
// Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
zlibr, err := NewReaderDict(piper, dict)
|
||||
|
@ -78,13 +75,8 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
|
|||
}
|
||||
defer zlibr.Close()
|
||||
|
||||
// Compare the two.
|
||||
b0, err0 := ioutil.ReadAll(golden)
|
||||
// Compare the decompressed data.
|
||||
b1, err1 := ioutil.ReadAll(zlibr)
|
||||
if err0 != nil {
|
||||
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
|
||||
return
|
||||
}
|
||||
if err1 != nil {
|
||||
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
|
||||
return
|
||||
|
@ -102,6 +94,18 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
|
|||
}
|
||||
|
||||
func TestWriter(t *testing.T) {
|
||||
for i, s := range data {
|
||||
b := []byte(s)
|
||||
tag := fmt.Sprintf("#%d", i)
|
||||
testLevelDict(t, tag, b, DefaultCompression, "")
|
||||
testLevelDict(t, tag, b, NoCompression, "")
|
||||
for level := BestSpeed; level <= BestCompression; level++ {
|
||||
testLevelDict(t, tag, b, level, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterBig(t *testing.T) {
|
||||
for _, fn := range filenames {
|
||||
testFileLevelDict(t, fn, DefaultCompression, "")
|
||||
testFileLevelDict(t, fn, NoCompression, "")
|
||||
|
@ -121,3 +125,20 @@ func TestWriterDict(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterDictIsUsed(t *testing.T) {
|
||||
var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
buf := bytes.NewBuffer(nil)
|
||||
compressor, err := NewWriterDict(buf, BestCompression, input)
|
||||
if err != nil {
|
||||
t.Errorf("error in NewWriterDict: %s", err)
|
||||
return
|
||||
}
|
||||
compressor.Write(input)
|
||||
compressor.Close()
|
||||
const expectedMaxSize = 25
|
||||
output := buf.Bytes()
|
||||
if len(output) > expectedMaxSize {
|
||||
t.Errorf("result too large (got %d, want <= %d bytes). Is the dictionary being used?", len(output), expectedMaxSize)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,8 +21,7 @@ type Interface interface {
|
|||
Pop() interface{}
|
||||
}
|
||||
|
||||
|
||||
// A heaper must be initialized before any of the heap operations
|
||||
// A heap must be initialized before any of the heap operations
|
||||
// can be used. Init is idempotent with respect to the heap invariants
|
||||
// and may be called whenever the heap invariants may have been invalidated.
|
||||
// Its complexity is O(n) where n = h.Len().
|
||||
|
@ -35,7 +34,6 @@ func Init(h Interface) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Push pushes the element x onto the heap. The complexity is
|
||||
// O(log(n)) where n = h.Len().
|
||||
//
|
||||
|
@ -44,7 +42,6 @@ func Push(h Interface, x interface{}) {
|
|||
up(h, h.Len()-1)
|
||||
}
|
||||
|
||||
|
||||
// Pop removes the minimum element (according to Less) from the heap
|
||||
// and returns it. The complexity is O(log(n)) where n = h.Len().
|
||||
// Same as Remove(h, 0).
|
||||
|
@ -56,7 +53,6 @@ func Pop(h Interface) interface{} {
|
|||
return h.Pop()
|
||||
}
|
||||
|
||||
|
||||
// Remove removes the element at index i from the heap.
|
||||
// The complexity is O(log(n)) where n = h.Len().
|
||||
//
|
||||
|
@ -70,7 +66,6 @@ func Remove(h Interface, i int) interface{} {
|
|||
return h.Pop()
|
||||
}
|
||||
|
||||
|
||||
func up(h Interface, j int) {
|
||||
for {
|
||||
i := (j - 1) / 2 // parent
|
||||
|
@ -82,7 +77,6 @@ func up(h Interface, j int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func down(h Interface, i, n int) {
|
||||
for {
|
||||
j1 := 2*i + 1
|
||||
|
|
|
@ -10,17 +10,14 @@ import (
|
|||
. "container/heap"
|
||||
)
|
||||
|
||||
|
||||
type myHeap struct {
|
||||
// A vector.Vector implements sort.Interface except for Less,
|
||||
// and it implements Push and Pop as required for heap.Interface.
|
||||
vector.Vector
|
||||
}
|
||||
|
||||
|
||||
func (h *myHeap) Less(i, j int) bool { return h.At(i).(int) < h.At(j).(int) }
|
||||
|
||||
|
||||
func (h *myHeap) verify(t *testing.T, i int) {
|
||||
n := h.Len()
|
||||
j1 := 2*i + 1
|
||||
|
@ -41,7 +38,6 @@ func (h *myHeap) verify(t *testing.T, i int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestInit0(t *testing.T) {
|
||||
h := new(myHeap)
|
||||
for i := 20; i > 0; i-- {
|
||||
|
@ -59,7 +55,6 @@ func TestInit0(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestInit1(t *testing.T) {
|
||||
h := new(myHeap)
|
||||
for i := 20; i > 0; i-- {
|
||||
|
@ -77,7 +72,6 @@ func TestInit1(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func Test(t *testing.T) {
|
||||
h := new(myHeap)
|
||||
h.verify(t, 0)
|
||||
|
@ -105,7 +99,6 @@ func Test(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestRemove0(t *testing.T) {
|
||||
h := new(myHeap)
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -123,7 +116,6 @@ func TestRemove0(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestRemove1(t *testing.T) {
|
||||
h := new(myHeap)
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -140,7 +132,6 @@ func TestRemove1(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestRemove2(t *testing.T) {
|
||||
N := 10
|
||||
|
||||
|
|
|
@ -16,14 +16,12 @@ type Ring struct {
|
|||
Value interface{} // for use by client; untouched by this library
|
||||
}
|
||||
|
||||
|
||||
func (r *Ring) init() *Ring {
|
||||
r.next = r
|
||||
r.prev = r
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
// Next returns the next ring element. r must not be empty.
|
||||
func (r *Ring) Next() *Ring {
|
||||
if r.next == nil {
|
||||
|
@ -32,7 +30,6 @@ func (r *Ring) Next() *Ring {
|
|||
return r.next
|
||||
}
|
||||
|
||||
|
||||
// Prev returns the previous ring element. r must not be empty.
|
||||
func (r *Ring) Prev() *Ring {
|
||||
if r.next == nil {
|
||||
|
@ -41,7 +38,6 @@ func (r *Ring) Prev() *Ring {
|
|||
return r.prev
|
||||
}
|
||||
|
||||
|
||||
// Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0)
|
||||
// in the ring and returns that ring element. r must not be empty.
|
||||
//
|
||||
|
@ -62,7 +58,6 @@ func (r *Ring) Move(n int) *Ring {
|
|||
return r
|
||||
}
|
||||
|
||||
|
||||
// New creates a ring of n elements.
|
||||
func New(n int) *Ring {
|
||||
if n <= 0 {
|
||||
|
@ -79,7 +74,6 @@ func New(n int) *Ring {
|
|||
return r
|
||||
}
|
||||
|
||||
|
||||
// Link connects ring r with with ring s such that r.Next()
|
||||
// becomes s and returns the original value for r.Next().
|
||||
// r must not be empty.
|
||||
|
@ -110,7 +104,6 @@ func (r *Ring) Link(s *Ring) *Ring {
|
|||
return n
|
||||
}
|
||||
|
||||
|
||||
// Unlink removes n % r.Len() elements from the ring r, starting
|
||||
// at r.Next(). If n % r.Len() == 0, r remains unchanged.
|
||||
// The result is the removed subring. r must not be empty.
|
||||
|
@ -122,7 +115,6 @@ func (r *Ring) Unlink(n int) *Ring {
|
|||
return r.Link(r.Move(n + 1))
|
||||
}
|
||||
|
||||
|
||||
// Len computes the number of elements in ring r.
|
||||
// It executes in time proportional to the number of elements.
|
||||
//
|
||||
|
@ -137,7 +129,6 @@ func (r *Ring) Len() int {
|
|||
return n
|
||||
}
|
||||
|
||||
|
||||
// Do calls function f on each element of the ring, in forward order.
|
||||
// The behavior of Do is undefined if f changes *r.
|
||||
func (r *Ring) Do(f func(interface{})) {
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
|
||||
// For debugging - keep around.
|
||||
func dump(r *Ring) {
|
||||
if r == nil {
|
||||
|
@ -24,7 +23,6 @@ func dump(r *Ring) {
|
|||
fmt.Println()
|
||||
}
|
||||
|
||||
|
||||
func verify(t *testing.T, r *Ring, N int, sum int) {
|
||||
// Len
|
||||
n := r.Len()
|
||||
|
@ -96,7 +94,6 @@ func verify(t *testing.T, r *Ring, N int, sum int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestCornerCases(t *testing.T) {
|
||||
var (
|
||||
r0 *Ring
|
||||
|
@ -118,7 +115,6 @@ func TestCornerCases(t *testing.T) {
|
|||
verify(t, &r1, 1, 0)
|
||||
}
|
||||
|
||||
|
||||
func makeN(n int) *Ring {
|
||||
r := New(n)
|
||||
for i := 1; i <= n; i++ {
|
||||
|
@ -130,7 +126,6 @@ func makeN(n int) *Ring {
|
|||
|
||||
func sumN(n int) int { return (n*n + n) / 2 }
|
||||
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
r := New(i)
|
||||
|
@ -142,7 +137,6 @@ func TestNew(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestLink1(t *testing.T) {
|
||||
r1a := makeN(1)
|
||||
var r1b Ring
|
||||
|
@ -163,7 +157,6 @@ func TestLink1(t *testing.T) {
|
|||
verify(t, r2b, 1, 0)
|
||||
}
|
||||
|
||||
|
||||
func TestLink2(t *testing.T) {
|
||||
var r0 *Ring
|
||||
r1a := &Ring{Value: 42}
|
||||
|
@ -183,7 +176,6 @@ func TestLink2(t *testing.T) {
|
|||
verify(t, r10, 12, sumN(10)+42+77)
|
||||
}
|
||||
|
||||
|
||||
func TestLink3(t *testing.T) {
|
||||
var r Ring
|
||||
n := 1
|
||||
|
@ -193,7 +185,6 @@ func TestLink3(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
r10 := makeN(10)
|
||||
s10 := r10.Move(6)
|
||||
|
@ -215,7 +206,6 @@ func TestUnlink(t *testing.T) {
|
|||
verify(t, r10, 9, sum10-2)
|
||||
}
|
||||
|
||||
|
||||
func TestLinkUnlink(t *testing.T) {
|
||||
for i := 1; i < 4; i++ {
|
||||
ri := New(i)
|
||||
|
|
|
@ -6,29 +6,24 @@
|
|||
// Vectors grow and shrink dynamically as necessary.
|
||||
package vector
|
||||
|
||||
|
||||
// Vector is a container for numbered sequences of elements of type interface{}.
|
||||
// A vector's length and capacity adjusts automatically as necessary.
|
||||
// The zero value for Vector is an empty vector ready to use.
|
||||
type Vector []interface{}
|
||||
|
||||
|
||||
// IntVector is a container for numbered sequences of elements of type int.
|
||||
// A vector's length and capacity adjusts automatically as necessary.
|
||||
// The zero value for IntVector is an empty vector ready to use.
|
||||
type IntVector []int
|
||||
|
||||
|
||||
// StringVector is a container for numbered sequences of elements of type string.
|
||||
// A vector's length and capacity adjusts automatically as necessary.
|
||||
// The zero value for StringVector is an empty vector ready to use.
|
||||
type StringVector []string
|
||||
|
||||
|
||||
// Initial underlying array size
|
||||
const initialSize = 8
|
||||
|
||||
|
||||
// Partial sort.Interface support
|
||||
|
||||
// LessInterface provides partial support of the sort.Interface.
|
||||
|
@ -36,16 +31,13 @@ type LessInterface interface {
|
|||
Less(y interface{}) bool
|
||||
}
|
||||
|
||||
|
||||
// Less returns a boolean denoting whether the i'th element is less than the j'th element.
|
||||
func (p *Vector) Less(i, j int) bool { return (*p)[i].(LessInterface).Less((*p)[j]) }
|
||||
|
||||
|
||||
// sort.Interface support
|
||||
|
||||
// Less returns a boolean denoting whether the i'th element is less than the j'th element.
|
||||
func (p *IntVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] }
|
||||
|
||||
|
||||
// Less returns a boolean denoting whether the i'th element is less than the j'th element.
|
||||
func (p *StringVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] }
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
|
||||
package vector
|
||||
|
||||
|
||||
func (p *IntVector) realloc(length, capacity int) (b []int) {
|
||||
if capacity < initialSize {
|
||||
capacity = initialSize
|
||||
|
@ -21,7 +20,6 @@ func (p *IntVector) realloc(length, capacity int) (b []int) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// Insert n elements at position i.
|
||||
func (p *IntVector) Expand(i, n int) {
|
||||
a := *p
|
||||
|
@ -51,11 +49,9 @@ func (p *IntVector) Expand(i, n int) {
|
|||
*p = a
|
||||
}
|
||||
|
||||
|
||||
// Insert n elements at the end of a vector.
|
||||
func (p *IntVector) Extend(n int) { p.Expand(len(*p), n) }
|
||||
|
||||
|
||||
// Resize changes the length and capacity of a vector.
|
||||
// If the new length is shorter than the current length, Resize discards
|
||||
// trailing elements. If the new length is longer than the current length,
|
||||
|
@ -80,30 +76,24 @@ func (p *IntVector) Resize(length, capacity int) *IntVector {
|
|||
return p
|
||||
}
|
||||
|
||||
|
||||
// Len returns the number of elements in the vector.
|
||||
// Same as len(*p).
|
||||
func (p *IntVector) Len() int { return len(*p) }
|
||||
|
||||
|
||||
// Cap returns the capacity of the vector; that is, the
|
||||
// maximum length the vector can grow without resizing.
|
||||
// Same as cap(*p).
|
||||
func (p *IntVector) Cap() int { return cap(*p) }
|
||||
|
||||
|
||||
// At returns the i'th element of the vector.
|
||||
func (p *IntVector) At(i int) int { return (*p)[i] }
|
||||
|
||||
|
||||
// Set sets the i'th element of the vector to value x.
|
||||
func (p *IntVector) Set(i int, x int) { (*p)[i] = x }
|
||||
|
||||
|
||||
// Last returns the element in the vector of highest index.
|
||||
func (p *IntVector) Last() int { return (*p)[len(*p)-1] }
|
||||
|
||||
|
||||
// Copy makes a copy of the vector and returns it.
|
||||
func (p *IntVector) Copy() IntVector {
|
||||
arr := make(IntVector, len(*p))
|
||||
|
@ -111,7 +101,6 @@ func (p *IntVector) Copy() IntVector {
|
|||
return arr
|
||||
}
|
||||
|
||||
|
||||
// Insert inserts into the vector an element of value x before
|
||||
// the current element at index i.
|
||||
func (p *IntVector) Insert(i int, x int) {
|
||||
|
@ -119,7 +108,6 @@ func (p *IntVector) Insert(i int, x int) {
|
|||
(*p)[i] = x
|
||||
}
|
||||
|
||||
|
||||
// Delete deletes the i'th element of the vector. The gap is closed so the old
|
||||
// element at index i+1 has index i afterwards.
|
||||
func (p *IntVector) Delete(i int) {
|
||||
|
@ -132,7 +120,6 @@ func (p *IntVector) Delete(i int) {
|
|||
*p = a[0 : n-1]
|
||||
}
|
||||
|
||||
|
||||
// InsertVector inserts into the vector the contents of the vector
|
||||
// x such that the 0th element of x appears at index i after insertion.
|
||||
func (p *IntVector) InsertVector(i int, x *IntVector) {
|
||||
|
@ -142,7 +129,6 @@ func (p *IntVector) InsertVector(i int, x *IntVector) {
|
|||
copy((*p)[i:i+len(b)], b)
|
||||
}
|
||||
|
||||
|
||||
// Cut deletes elements i through j-1, inclusive.
|
||||
func (p *IntVector) Cut(i, j int) {
|
||||
a := *p
|
||||
|
@ -158,7 +144,6 @@ func (p *IntVector) Cut(i, j int) {
|
|||
*p = a[0:m]
|
||||
}
|
||||
|
||||
|
||||
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
|
||||
// The elements are copied. The original vector is unchanged.
|
||||
func (p *IntVector) Slice(i, j int) *IntVector {
|
||||
|
@ -168,13 +153,11 @@ func (p *IntVector) Slice(i, j int) *IntVector {
|
|||
return &s
|
||||
}
|
||||
|
||||
|
||||
// Convenience wrappers
|
||||
|
||||
// Push appends x to the end of the vector.
|
||||
func (p *IntVector) Push(x int) { p.Insert(len(*p), x) }
|
||||
|
||||
|
||||
// Pop deletes the last element of the vector.
|
||||
func (p *IntVector) Pop() int {
|
||||
a := *p
|
||||
|
@ -187,18 +170,15 @@ func (p *IntVector) Pop() int {
|
|||
return x
|
||||
}
|
||||
|
||||
|
||||
// AppendVector appends the entire vector x to the end of this vector.
|
||||
func (p *IntVector) AppendVector(x *IntVector) { p.InsertVector(len(*p), x) }
|
||||
|
||||
|
||||
// Swap exchanges the elements at indexes i and j.
|
||||
func (p *IntVector) Swap(i, j int) {
|
||||
a := *p
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
|
||||
|
||||
// Do calls function f for each element of the vector, in order.
|
||||
// The behavior of Do is undefined if f changes *p.
|
||||
func (p *IntVector) Do(f func(elem int)) {
|
||||
|
|
|
@ -9,7 +9,6 @@ package vector
|
|||
|
||||
import "testing"
|
||||
|
||||
|
||||
func TestIntZeroLen(t *testing.T) {
|
||||
a := new(IntVector)
|
||||
if a.Len() != 0 {
|
||||
|
@ -27,7 +26,6 @@ func TestIntZeroLen(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestIntResize(t *testing.T) {
|
||||
var a IntVector
|
||||
checkSize(t, &a, 0, 0)
|
||||
|
@ -40,7 +38,6 @@ func TestIntResize(t *testing.T) {
|
|||
checkSize(t, a.Resize(11, 100), 11, 100)
|
||||
}
|
||||
|
||||
|
||||
func TestIntResize2(t *testing.T) {
|
||||
var a IntVector
|
||||
checkSize(t, &a, 0, 0)
|
||||
|
@ -62,7 +59,6 @@ func TestIntResize2(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func checkIntZero(t *testing.T, a *IntVector, i int) {
|
||||
for j := 0; j < i; j++ {
|
||||
if a.At(j) == intzero {
|
||||
|
@ -82,7 +78,6 @@ func checkIntZero(t *testing.T, a *IntVector, i int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestIntTrailingElements(t *testing.T) {
|
||||
var a IntVector
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -95,7 +90,6 @@ func TestIntTrailingElements(t *testing.T) {
|
|||
checkIntZero(t, &a, 5)
|
||||
}
|
||||
|
||||
|
||||
func TestIntAccess(t *testing.T) {
|
||||
const n = 100
|
||||
var a IntVector
|
||||
|
@ -120,7 +114,6 @@ func TestIntAccess(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestIntInsertDeleteClear(t *testing.T) {
|
||||
const n = 100
|
||||
var a IntVector
|
||||
|
@ -207,7 +200,6 @@ func TestIntInsertDeleteClear(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) {
|
||||
for k := i; k < j; k++ {
|
||||
if elem2IntValue(x.At(k)) != int2IntValue(elt) {
|
||||
|
@ -223,7 +215,6 @@ func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) {
|
||||
n := a + b + c
|
||||
if x.Len() != n {
|
||||
|
@ -237,7 +228,6 @@ func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) {
|
|||
verify_sliceInt(t, x, 0, a+b, n)
|
||||
}
|
||||
|
||||
|
||||
func make_vectorInt(elt, len int) *IntVector {
|
||||
x := new(IntVector).Resize(len, 0)
|
||||
for i := 0; i < len; i++ {
|
||||
|
@ -246,7 +236,6 @@ func make_vectorInt(elt, len int) *IntVector {
|
|||
return x
|
||||
}
|
||||
|
||||
|
||||
func TestIntInsertVector(t *testing.T) {
|
||||
// 1
|
||||
a := make_vectorInt(0, 0)
|
||||
|
@ -270,7 +259,6 @@ func TestIntInsertVector(t *testing.T) {
|
|||
verify_patternInt(t, a, 8, 1000, 2)
|
||||
}
|
||||
|
||||
|
||||
func TestIntDo(t *testing.T) {
|
||||
const n = 25
|
||||
const salt = 17
|
||||
|
@ -325,7 +313,6 @@ func TestIntDo(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
func TestIntVectorCopy(t *testing.T) {
|
||||
// verify Copy() returns a copy, not simply a slice of the original vector
|
||||
const Len = 10
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
|
||||
package vector
|
||||
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
@ -17,28 +16,23 @@ var (
|
|||
strzero string
|
||||
)
|
||||
|
||||
|
||||
func int2Value(x int) int { return x }
|
||||
func int2IntValue(x int) int { return x }
|
||||
func int2StrValue(x int) string { return string(x) }
|
||||
|
||||
|
||||
func elem2Value(x interface{}) int { return x.(int) }
|
||||
func elem2IntValue(x int) int { return x }
|
||||
func elem2StrValue(x string) string { return x }
|
||||
|
||||
|
||||
func intf2Value(x interface{}) int { return x.(int) }
|
||||
func intf2IntValue(x interface{}) int { return x.(int) }
|
||||
func intf2StrValue(x interface{}) string { return x.(string) }
|
||||
|
||||
|
||||
type VectorInterface interface {
|
||||
Len() int
|
||||
Cap() int
|
||||
}
|
||||
|
||||
|
||||
func checkSize(t *testing.T, v VectorInterface, len, cap int) {
|
||||
if v.Len() != len {
|
||||
t.Errorf("%T expected len = %d; found %d", v, len, v.Len())
|
||||
|
@ -48,10 +42,8 @@ func checkSize(t *testing.T, v VectorInterface, len, cap int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func val(i int) int { return i*991 - 1234 }
|
||||
|
||||
|
||||
func TestSorting(t *testing.T) {
|
||||
const n = 100
|
||||
|
||||
|
@ -72,5 +64,4 @@ func TestSorting(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func tname(x interface{}) string { return fmt.Sprintf("%T: ", x) }
|
||||
|
|
|
@ -11,10 +11,8 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
|
||||
const memTestN = 1000000
|
||||
|
||||
|
||||
func s(n uint64) string {
|
||||
str := fmt.Sprintf("%d", n)
|
||||
lens := len(str)
|
||||
|
@ -31,7 +29,6 @@ func s(n uint64) string {
|
|||
return strings.Join(a, " ")
|
||||
}
|
||||
|
||||
|
||||
func TestVectorNums(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
|
@ -52,7 +49,6 @@ func TestVectorNums(t *testing.T) {
|
|||
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
|
||||
}
|
||||
|
||||
|
||||
func TestIntVectorNums(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
|
@ -73,7 +69,6 @@ func TestIntVectorNums(t *testing.T) {
|
|||
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
|
||||
}
|
||||
|
||||
|
||||
func TestStringVectorNums(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
|
@ -94,7 +89,6 @@ func TestStringVectorNums(t *testing.T) {
|
|||
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
|
||||
}
|
||||
|
||||
|
||||
func BenchmarkVectorNums(b *testing.B) {
|
||||
c := int(0)
|
||||
var v Vector
|
||||
|
@ -106,7 +100,6 @@ func BenchmarkVectorNums(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func BenchmarkIntVectorNums(b *testing.B) {
|
||||
c := int(0)
|
||||
var v IntVector
|
||||
|
@ -118,7 +111,6 @@ func BenchmarkIntVectorNums(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func BenchmarkStringVectorNums(b *testing.B) {
|
||||
c := ""
|
||||
var v StringVector
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
|
||||
package vector
|
||||
|
||||
|
||||
func (p *StringVector) realloc(length, capacity int) (b []string) {
|
||||
if capacity < initialSize {
|
||||
capacity = initialSize
|
||||
|
@ -21,7 +20,6 @@ func (p *StringVector) realloc(length, capacity int) (b []string) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// Insert n elements at position i.
|
||||
func (p *StringVector) Expand(i, n int) {
|
||||
a := *p
|
||||
|
@ -51,11 +49,9 @@ func (p *StringVector) Expand(i, n int) {
|
|||
*p = a
|
||||
}
|
||||
|
||||
|
||||
// Insert n elements at the end of a vector.
|
||||
func (p *StringVector) Extend(n int) { p.Expand(len(*p), n) }
|
||||
|
||||
|
||||
// Resize changes the length and capacity of a vector.
|
||||
// If the new length is shorter than the current length, Resize discards
|
||||
// trailing elements. If the new length is longer than the current length,
|
||||
|
@ -80,30 +76,24 @@ func (p *StringVector) Resize(length, capacity int) *StringVector {
|
|||
return p
|
||||
}
|
||||
|
||||
|
||||
// Len returns the number of elements in the vector.
|
||||
// Same as len(*p).
|
||||
func (p *StringVector) Len() int { return len(*p) }
|
||||
|
||||
|
||||
// Cap returns the capacity of the vector; that is, the
|
||||
// maximum length the vector can grow without resizing.
|
||||
// Same as cap(*p).
|
||||
func (p *StringVector) Cap() int { return cap(*p) }
|
||||
|
||||
|
||||
// At returns the i'th element of the vector.
|
||||
func (p *StringVector) At(i int) string { return (*p)[i] }
|
||||
|
||||
|
||||
// Set sets the i'th element of the vector to value x.
|
||||
func (p *StringVector) Set(i int, x string) { (*p)[i] = x }
|
||||
|
||||
|
||||
// Last returns the element in the vector of highest index.
|
||||
func (p *StringVector) Last() string { return (*p)[len(*p)-1] }
|
||||
|
||||
|
||||
// Copy makes a copy of the vector and returns it.
|
||||
func (p *StringVector) Copy() StringVector {
|
||||
arr := make(StringVector, len(*p))
|
||||
|
@ -111,7 +101,6 @@ func (p *StringVector) Copy() StringVector {
|
|||
return arr
|
||||
}
|
||||
|
||||
|
||||
// Insert inserts into the vector an element of value x before
|
||||
// the current element at index i.
|
||||
func (p *StringVector) Insert(i int, x string) {
|
||||
|
@ -119,7 +108,6 @@ func (p *StringVector) Insert(i int, x string) {
|
|||
(*p)[i] = x
|
||||
}
|
||||
|
||||
|
||||
// Delete deletes the i'th element of the vector. The gap is closed so the old
|
||||
// element at index i+1 has index i afterwards.
|
||||
func (p *StringVector) Delete(i int) {
|
||||
|
@ -132,7 +120,6 @@ func (p *StringVector) Delete(i int) {
|
|||
*p = a[0 : n-1]
|
||||
}
|
||||
|
||||
|
||||
// InsertVector inserts into the vector the contents of the vector
|
||||
// x such that the 0th element of x appears at index i after insertion.
|
||||
func (p *StringVector) InsertVector(i int, x *StringVector) {
|
||||
|
@ -142,7 +129,6 @@ func (p *StringVector) InsertVector(i int, x *StringVector) {
|
|||
copy((*p)[i:i+len(b)], b)
|
||||
}
|
||||
|
||||
|
||||
// Cut deletes elements i through j-1, inclusive.
|
||||
func (p *StringVector) Cut(i, j int) {
|
||||
a := *p
|
||||
|
@ -158,7 +144,6 @@ func (p *StringVector) Cut(i, j int) {
|
|||
*p = a[0:m]
|
||||
}
|
||||
|
||||
|
||||
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
|
||||
// The elements are copied. The original vector is unchanged.
|
||||
func (p *StringVector) Slice(i, j int) *StringVector {
|
||||
|
@ -168,13 +153,11 @@ func (p *StringVector) Slice(i, j int) *StringVector {
|
|||
return &s
|
||||
}
|
||||
|
||||
|
||||
// Convenience wrappers
|
||||
|
||||
// Push appends x to the end of the vector.
|
||||
func (p *StringVector) Push(x string) { p.Insert(len(*p), x) }
|
||||
|
||||
|
||||
// Pop deletes the last element of the vector.
|
||||
func (p *StringVector) Pop() string {
|
||||
a := *p
|
||||
|
@ -187,18 +170,15 @@ func (p *StringVector) Pop() string {
|
|||
return x
|
||||
}
|
||||
|
||||
|
||||
// AppendVector appends the entire vector x to the end of this vector.
|
||||
func (p *StringVector) AppendVector(x *StringVector) { p.InsertVector(len(*p), x) }
|
||||
|
||||
|
||||
// Swap exchanges the elements at indexes i and j.
|
||||
func (p *StringVector) Swap(i, j int) {
|
||||
a := *p
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
|
||||
|
||||
// Do calls function f for each element of the vector, in order.
|
||||
// The behavior of Do is undefined if f changes *p.
|
||||
func (p *StringVector) Do(f func(elem string)) {
|
||||
|
|
|
@ -9,7 +9,6 @@ package vector
|
|||
|
||||
import "testing"
|
||||
|
||||
|
||||
func TestStrZeroLen(t *testing.T) {
|
||||
a := new(StringVector)
|
||||
if a.Len() != 0 {
|
||||
|
@ -27,7 +26,6 @@ func TestStrZeroLen(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestStrResize(t *testing.T) {
|
||||
var a StringVector
|
||||
checkSize(t, &a, 0, 0)
|
||||
|
@ -40,7 +38,6 @@ func TestStrResize(t *testing.T) {
|
|||
checkSize(t, a.Resize(11, 100), 11, 100)
|
||||
}
|
||||
|
||||
|
||||
func TestStrResize2(t *testing.T) {
|
||||
var a StringVector
|
||||
checkSize(t, &a, 0, 0)
|
||||
|
@ -62,7 +59,6 @@ func TestStrResize2(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func checkStrZero(t *testing.T, a *StringVector, i int) {
|
||||
for j := 0; j < i; j++ {
|
||||
if a.At(j) == strzero {
|
||||
|
@ -82,7 +78,6 @@ func checkStrZero(t *testing.T, a *StringVector, i int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestStrTrailingElements(t *testing.T) {
|
||||
var a StringVector
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -95,7 +90,6 @@ func TestStrTrailingElements(t *testing.T) {
|
|||
checkStrZero(t, &a, 5)
|
||||
}
|
||||
|
||||
|
||||
func TestStrAccess(t *testing.T) {
|
||||
const n = 100
|
||||
var a StringVector
|
||||
|
@ -120,7 +114,6 @@ func TestStrAccess(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestStrInsertDeleteClear(t *testing.T) {
|
||||
const n = 100
|
||||
var a StringVector
|
||||
|
@ -207,7 +200,6 @@ func TestStrInsertDeleteClear(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) {
|
||||
for k := i; k < j; k++ {
|
||||
if elem2StrValue(x.At(k)) != int2StrValue(elt) {
|
||||
|
@ -223,7 +215,6 @@ func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) {
|
||||
n := a + b + c
|
||||
if x.Len() != n {
|
||||
|
@ -237,7 +228,6 @@ func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) {
|
|||
verify_sliceStr(t, x, 0, a+b, n)
|
||||
}
|
||||
|
||||
|
||||
func make_vectorStr(elt, len int) *StringVector {
|
||||
x := new(StringVector).Resize(len, 0)
|
||||
for i := 0; i < len; i++ {
|
||||
|
@ -246,7 +236,6 @@ func make_vectorStr(elt, len int) *StringVector {
|
|||
return x
|
||||
}
|
||||
|
||||
|
||||
func TestStrInsertVector(t *testing.T) {
|
||||
// 1
|
||||
a := make_vectorStr(0, 0)
|
||||
|
@ -270,7 +259,6 @@ func TestStrInsertVector(t *testing.T) {
|
|||
verify_patternStr(t, a, 8, 1000, 2)
|
||||
}
|
||||
|
||||
|
||||
func TestStrDo(t *testing.T) {
|
||||
const n = 25
|
||||
const salt = 17
|
||||
|
@ -325,7 +313,6 @@ func TestStrDo(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
func TestStrVectorCopy(t *testing.T) {
|
||||
// verify Copy() returns a copy, not simply a slice of the original vector
|
||||
const Len = 10
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
|
||||
package vector
|
||||
|
||||
|
||||
func (p *Vector) realloc(length, capacity int) (b []interface{}) {
|
||||
if capacity < initialSize {
|
||||
capacity = initialSize
|
||||
|
@ -21,7 +20,6 @@ func (p *Vector) realloc(length, capacity int) (b []interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// Insert n elements at position i.
|
||||
func (p *Vector) Expand(i, n int) {
|
||||
a := *p
|
||||
|
@ -51,11 +49,9 @@ func (p *Vector) Expand(i, n int) {
|
|||
*p = a
|
||||
}
|
||||
|
||||
|
||||
// Insert n elements at the end of a vector.
|
||||
func (p *Vector) Extend(n int) { p.Expand(len(*p), n) }
|
||||
|
||||
|
||||
// Resize changes the length and capacity of a vector.
|
||||
// If the new length is shorter than the current length, Resize discards
|
||||
// trailing elements. If the new length is longer than the current length,
|
||||
|
@ -80,30 +76,24 @@ func (p *Vector) Resize(length, capacity int) *Vector {
|
|||
return p
|
||||
}
|
||||
|
||||
|
||||
// Len returns the number of elements in the vector.
|
||||
// Same as len(*p).
|
||||
func (p *Vector) Len() int { return len(*p) }
|
||||
|
||||
|
||||
// Cap returns the capacity of the vector; that is, the
|
||||
// maximum length the vector can grow without resizing.
|
||||
// Same as cap(*p).
|
||||
func (p *Vector) Cap() int { return cap(*p) }
|
||||
|
||||
|
||||
// At returns the i'th element of the vector.
|
||||
func (p *Vector) At(i int) interface{} { return (*p)[i] }
|
||||
|
||||
|
||||
// Set sets the i'th element of the vector to value x.
|
||||
func (p *Vector) Set(i int, x interface{}) { (*p)[i] = x }
|
||||
|
||||
|
||||
// Last returns the element in the vector of highest index.
|
||||
func (p *Vector) Last() interface{} { return (*p)[len(*p)-1] }
|
||||
|
||||
|
||||
// Copy makes a copy of the vector and returns it.
|
||||
func (p *Vector) Copy() Vector {
|
||||
arr := make(Vector, len(*p))
|
||||
|
@ -111,7 +101,6 @@ func (p *Vector) Copy() Vector {
|
|||
return arr
|
||||
}
|
||||
|
||||
|
||||
// Insert inserts into the vector an element of value x before
|
||||
// the current element at index i.
|
||||
func (p *Vector) Insert(i int, x interface{}) {
|
||||
|
@ -119,7 +108,6 @@ func (p *Vector) Insert(i int, x interface{}) {
|
|||
(*p)[i] = x
|
||||
}
|
||||
|
||||
|
||||
// Delete deletes the i'th element of the vector. The gap is closed so the old
|
||||
// element at index i+1 has index i afterwards.
|
||||
func (p *Vector) Delete(i int) {
|
||||
|
@ -132,7 +120,6 @@ func (p *Vector) Delete(i int) {
|
|||
*p = a[0 : n-1]
|
||||
}
|
||||
|
||||
|
||||
// InsertVector inserts into the vector the contents of the vector
|
||||
// x such that the 0th element of x appears at index i after insertion.
|
||||
func (p *Vector) InsertVector(i int, x *Vector) {
|
||||
|
@ -142,7 +129,6 @@ func (p *Vector) InsertVector(i int, x *Vector) {
|
|||
copy((*p)[i:i+len(b)], b)
|
||||
}
|
||||
|
||||
|
||||
// Cut deletes elements i through j-1, inclusive.
|
||||
func (p *Vector) Cut(i, j int) {
|
||||
a := *p
|
||||
|
@ -158,7 +144,6 @@ func (p *Vector) Cut(i, j int) {
|
|||
*p = a[0:m]
|
||||
}
|
||||
|
||||
|
||||
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
|
||||
// The elements are copied. The original vector is unchanged.
|
||||
func (p *Vector) Slice(i, j int) *Vector {
|
||||
|
@ -168,13 +153,11 @@ func (p *Vector) Slice(i, j int) *Vector {
|
|||
return &s
|
||||
}
|
||||
|
||||
|
||||
// Convenience wrappers
|
||||
|
||||
// Push appends x to the end of the vector.
|
||||
func (p *Vector) Push(x interface{}) { p.Insert(len(*p), x) }
|
||||
|
||||
|
||||
// Pop deletes the last element of the vector.
|
||||
func (p *Vector) Pop() interface{} {
|
||||
a := *p
|
||||
|
@ -187,18 +170,15 @@ func (p *Vector) Pop() interface{} {
|
|||
return x
|
||||
}
|
||||
|
||||
|
||||
// AppendVector appends the entire vector x to the end of this vector.
|
||||
func (p *Vector) AppendVector(x *Vector) { p.InsertVector(len(*p), x) }
|
||||
|
||||
|
||||
// Swap exchanges the elements at indexes i and j.
|
||||
func (p *Vector) Swap(i, j int) {
|
||||
a := *p
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
|
||||
|
||||
// Do calls function f for each element of the vector, in order.
|
||||
// The behavior of Do is undefined if f changes *p.
|
||||
func (p *Vector) Do(f func(elem interface{})) {
|
||||
|
|
|
@ -9,7 +9,6 @@ package vector
|
|||
|
||||
import "testing"
|
||||
|
||||
|
||||
func TestZeroLen(t *testing.T) {
|
||||
a := new(Vector)
|
||||
if a.Len() != 0 {
|
||||
|
@ -27,7 +26,6 @@ func TestZeroLen(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestResize(t *testing.T) {
|
||||
var a Vector
|
||||
checkSize(t, &a, 0, 0)
|
||||
|
@ -40,7 +38,6 @@ func TestResize(t *testing.T) {
|
|||
checkSize(t, a.Resize(11, 100), 11, 100)
|
||||
}
|
||||
|
||||
|
||||
func TestResize2(t *testing.T) {
|
||||
var a Vector
|
||||
checkSize(t, &a, 0, 0)
|
||||
|
@ -62,7 +59,6 @@ func TestResize2(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func checkZero(t *testing.T, a *Vector, i int) {
|
||||
for j := 0; j < i; j++ {
|
||||
if a.At(j) == zero {
|
||||
|
@ -82,7 +78,6 @@ func checkZero(t *testing.T, a *Vector, i int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestTrailingElements(t *testing.T) {
|
||||
var a Vector
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -95,7 +90,6 @@ func TestTrailingElements(t *testing.T) {
|
|||
checkZero(t, &a, 5)
|
||||
}
|
||||
|
||||
|
||||
func TestAccess(t *testing.T) {
|
||||
const n = 100
|
||||
var a Vector
|
||||
|
@ -120,7 +114,6 @@ func TestAccess(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestInsertDeleteClear(t *testing.T) {
|
||||
const n = 100
|
||||
var a Vector
|
||||
|
@ -207,7 +200,6 @@ func TestInsertDeleteClear(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func verify_slice(t *testing.T, x *Vector, elt, i, j int) {
|
||||
for k := i; k < j; k++ {
|
||||
if elem2Value(x.At(k)) != int2Value(elt) {
|
||||
|
@ -223,7 +215,6 @@ func verify_slice(t *testing.T, x *Vector, elt, i, j int) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func verify_pattern(t *testing.T, x *Vector, a, b, c int) {
|
||||
n := a + b + c
|
||||
if x.Len() != n {
|
||||
|
@ -237,7 +228,6 @@ func verify_pattern(t *testing.T, x *Vector, a, b, c int) {
|
|||
verify_slice(t, x, 0, a+b, n)
|
||||
}
|
||||
|
||||
|
||||
func make_vector(elt, len int) *Vector {
|
||||
x := new(Vector).Resize(len, 0)
|
||||
for i := 0; i < len; i++ {
|
||||
|
@ -246,7 +236,6 @@ func make_vector(elt, len int) *Vector {
|
|||
return x
|
||||
}
|
||||
|
||||
|
||||
func TestInsertVector(t *testing.T) {
|
||||
// 1
|
||||
a := make_vector(0, 0)
|
||||
|
@ -270,7 +259,6 @@ func TestInsertVector(t *testing.T) {
|
|||
verify_pattern(t, a, 8, 1000, 2)
|
||||
}
|
||||
|
||||
|
||||
func TestDo(t *testing.T) {
|
||||
const n = 25
|
||||
const salt = 17
|
||||
|
@ -325,7 +313,6 @@ func TestDo(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
func TestVectorCopy(t *testing.T) {
|
||||
// verify Copy() returns a copy, not simply a slice of the original vector
|
||||
const Len = 10
|
||||
|
|
|
@ -45,14 +45,14 @@ func NewCipher(key []byte) (*Cipher, os.Error) {
|
|||
|
||||
// BlockSize returns the AES block size, 16 bytes.
|
||||
// It is necessary to satisfy the Cipher interface in the
|
||||
// package "crypto/block".
|
||||
// package "crypto/cipher".
|
||||
func (c *Cipher) BlockSize() int { return BlockSize }
|
||||
|
||||
// Encrypt encrypts the 16-byte buffer src using the key k
|
||||
// and stores the result in dst.
|
||||
// Note that for amounts of data larger than a block,
|
||||
// it is not safe to just call Encrypt on successive blocks;
|
||||
// instead, use an encryption mode like CBC (see crypto/block/cbc.go).
|
||||
// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
|
||||
func (c *Cipher) Encrypt(dst, src []byte) { encryptBlock(c.enc, dst, src) }
|
||||
|
||||
// Decrypt decrypts the 16-byte buffer src using the key k
|
||||
|
|
|
@ -42,14 +42,14 @@ func NewCipher(key []byte) (*Cipher, os.Error) {
|
|||
|
||||
// BlockSize returns the Blowfish block size, 8 bytes.
|
||||
// It is necessary to satisfy the Cipher interface in the
|
||||
// package "crypto/block".
|
||||
// package "crypto/cipher".
|
||||
func (c *Cipher) BlockSize() int { return BlockSize }
|
||||
|
||||
// Encrypt encrypts the 8-byte buffer src using the key k
|
||||
// and stores the result in dst.
|
||||
// Note that for amounts of data larger than a block,
|
||||
// it is not safe to just call Encrypt on successive blocks;
|
||||
// instead, use an encryption mode like CBC (see crypto/block/cbc.go).
|
||||
// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
|
||||
func (c *Cipher) Encrypt(dst, src []byte) {
|
||||
l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
|
||||
r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])
|
||||
|
|
|
@ -20,7 +20,7 @@ type Cipher struct {
|
|||
|
||||
func NewCipher(key []byte) (c *Cipher, err os.Error) {
|
||||
if len(key) != KeySize {
|
||||
return nil, os.ErrorString("CAST5: keys must be 16 bytes")
|
||||
return nil, os.NewError("CAST5: keys must be 16 bytes")
|
||||
}
|
||||
|
||||
c = new(Cipher)
|
||||
|
|
|
@ -80,9 +80,10 @@ type ocfbDecrypter struct {
|
|||
// NewOCFBDecrypter returns a Stream which decrypts data with OpenPGP's cipher
|
||||
// feedback mode using the given Block. Prefix must be the first blockSize + 2
|
||||
// bytes of the ciphertext, where blockSize is the Block's block size. If an
|
||||
// incorrect key is detected then nil is returned. Resync determines if the
|
||||
// "resynchronization step" from RFC 4880, 13.9 step 7 is performed. Different
|
||||
// parts of OpenPGP vary on this point.
|
||||
// incorrect key is detected then nil is returned. On successful exit,
|
||||
// blockSize+2 bytes of decrypted data are written into prefix. Resync
|
||||
// determines if the "resynchronization step" from RFC 4880, 13.9 step 7 is
|
||||
// performed. Different parts of OpenPGP vary on this point.
|
||||
func NewOCFBDecrypter(block Block, prefix []byte, resync OCFBResyncOption) Stream {
|
||||
blockSize := block.BlockSize()
|
||||
if len(prefix) != blockSize+2 {
|
||||
|
@ -118,6 +119,7 @@ func NewOCFBDecrypter(block Block, prefix []byte, resync OCFBResyncOption) Strea
|
|||
x.fre[1] = prefix[blockSize+1]
|
||||
x.outUsed = 2
|
||||
}
|
||||
copy(prefix, prefixCopy)
|
||||
return x
|
||||
}
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ func GenerateParameters(params *Parameters, rand io.Reader, sizes ParameterSizes
|
|||
L = 3072
|
||||
N = 256
|
||||
default:
|
||||
return os.ErrorString("crypto/dsa: invalid ParameterSizes")
|
||||
return os.NewError("crypto/dsa: invalid ParameterSizes")
|
||||
}
|
||||
|
||||
qBytes := make([]byte, N/8)
|
||||
|
@ -158,7 +158,7 @@ GeneratePrimes:
|
|||
// PrivateKey must already be valid (see GenerateParameters).
|
||||
func GenerateKey(priv *PrivateKey, rand io.Reader) os.Error {
|
||||
if priv.P == nil || priv.Q == nil || priv.G == nil {
|
||||
return os.ErrorString("crypto/dsa: parameters not set up before generating key")
|
||||
return os.NewError("crypto/dsa: parameters not set up before generating key")
|
||||
}
|
||||
|
||||
x := new(big.Int)
|
||||
|
|
|
@ -284,7 +284,7 @@ func (curve *Curve) Marshal(x, y *big.Int) []byte {
|
|||
return ret
|
||||
}
|
||||
|
||||
// Unmarshal converts a point, serialised by Marshal, into an x, y pair. On
|
||||
// Unmarshal converts a point, serialized by Marshal, into an x, y pair. On
|
||||
// error, x = nil.
|
||||
func (curve *Curve) Unmarshal(data []byte) (x, y *big.Int) {
|
||||
byteLen := (curve.BitSize + 7) >> 3
|
||||
|
|
|
@ -321,8 +321,8 @@ func TestMarshal(t *testing.T) {
|
|||
t.Error(err)
|
||||
return
|
||||
}
|
||||
serialised := p224.Marshal(x, y)
|
||||
xx, yy := p224.Unmarshal(serialised)
|
||||
serialized := p224.Marshal(x, y)
|
||||
xx, yy := p224.Unmarshal(serialized)
|
||||
if xx == nil {
|
||||
t.Error("failed to unmarshal")
|
||||
return
|
||||
|
|
|
@ -190,7 +190,7 @@ func TestHMAC(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
|
||||
// Repetive Sum() calls should return the same value
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", h.Sum())
|
||||
if sum != tt.out {
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"crypto/rsa"
|
||||
_ "crypto/sha1"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
@ -32,21 +33,8 @@ const (
|
|||
ocspUnauthorized = 5
|
||||
)
|
||||
|
||||
type rdnSequence []relativeDistinguishedNameSET
|
||||
|
||||
type relativeDistinguishedNameSET []attributeTypeAndValue
|
||||
|
||||
type attributeTypeAndValue struct {
|
||||
Type asn1.ObjectIdentifier
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
type algorithmIdentifier struct {
|
||||
Algorithm asn1.ObjectIdentifier
|
||||
}
|
||||
|
||||
type certID struct {
|
||||
HashAlgorithm algorithmIdentifier
|
||||
HashAlgorithm pkix.AlgorithmIdentifier
|
||||
NameHash []byte
|
||||
IssuerKeyHash []byte
|
||||
SerialNumber asn1.RawValue
|
||||
|
@ -54,7 +42,7 @@ type certID struct {
|
|||
|
||||
type responseASN1 struct {
|
||||
Status asn1.Enumerated
|
||||
Response responseBytes "explicit,tag:0"
|
||||
Response responseBytes `asn1:"explicit,tag:0"`
|
||||
}
|
||||
|
||||
type responseBytes struct {
|
||||
|
@ -64,32 +52,32 @@ type responseBytes struct {
|
|||
|
||||
type basicResponse struct {
|
||||
TBSResponseData responseData
|
||||
SignatureAlgorithm algorithmIdentifier
|
||||
SignatureAlgorithm pkix.AlgorithmIdentifier
|
||||
Signature asn1.BitString
|
||||
Certificates []asn1.RawValue "explicit,tag:0,optional"
|
||||
Certificates []asn1.RawValue `asn1:"explicit,tag:0,optional"`
|
||||
}
|
||||
|
||||
type responseData struct {
|
||||
Raw asn1.RawContent
|
||||
Version int "optional,default:1,explicit,tag:0"
|
||||
RequestorName rdnSequence "optional,explicit,tag:1"
|
||||
KeyHash []byte "optional,explicit,tag:2"
|
||||
Version int `asn1:"optional,default:1,explicit,tag:0"`
|
||||
RequestorName pkix.RDNSequence `asn1:"optional,explicit,tag:1"`
|
||||
KeyHash []byte `asn1:"optional,explicit,tag:2"`
|
||||
ProducedAt *time.Time
|
||||
Responses []singleResponse
|
||||
}
|
||||
|
||||
type singleResponse struct {
|
||||
CertID certID
|
||||
Good asn1.Flag "explicit,tag:0,optional"
|
||||
Revoked revokedInfo "explicit,tag:1,optional"
|
||||
Unknown asn1.Flag "explicit,tag:2,optional"
|
||||
Good asn1.Flag `asn1:"explicit,tag:0,optional"`
|
||||
Revoked revokedInfo `asn1:"explicit,tag:1,optional"`
|
||||
Unknown asn1.Flag `asn1:"explicit,tag:2,optional"`
|
||||
ThisUpdate *time.Time
|
||||
NextUpdate *time.Time "explicit,tag:0,optional"
|
||||
NextUpdate *time.Time `asn1:"explicit,tag:0,optional"`
|
||||
}
|
||||
|
||||
type revokedInfo struct {
|
||||
RevocationTime *time.Time
|
||||
Reason int "explicit,tag:0,optional"
|
||||
Reason int `asn1:"explicit,tag:0,optional"`
|
||||
}
|
||||
|
||||
// This is the exposed reflection of the internal OCSP structures.
|
||||
|
|
|
@ -153,7 +153,7 @@ func (r *openpgpReader) Read(p []byte) (n int, err os.Error) {
|
|||
|
||||
// Decode reads a PGP armored block from the given Reader. It will ignore
|
||||
// leading garbage. If it doesn't find a block, it will return nil, os.EOF. The
|
||||
// given Reader is not usable after calling this function: an arbitary amount
|
||||
// given Reader is not usable after calling this function: an arbitrary amount
|
||||
// of data may have been read past the end of the block.
|
||||
func Decode(in io.Reader) (p *Block, err os.Error) {
|
||||
r, _ := bufio.NewReaderSize(in, 100)
|
||||
|
|
|
@ -30,7 +30,6 @@ func (r recordingHash) Size() int {
|
|||
panic("shouldn't be called")
|
||||
}
|
||||
|
||||
|
||||
func testCanonicalText(t *testing.T, input, expected string) {
|
||||
r := recordingHash{bytes.NewBuffer(nil)}
|
||||
c := NewCanonicalTextHash(r)
|
||||
|
|
122
libgo/go/crypto/openpgp/elgamal/elgamal.go
Normal file
122
libgo/go/crypto/openpgp/elgamal/elgamal.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package elgamal implements ElGamal encryption, suitable for OpenPGP,
|
||||
// as specified in "A Public-Key Cryptosystem and a Signature Scheme Based on
|
||||
// Discrete Logarithms," IEEE Transactions on Information Theory, v. IT-31,
|
||||
// n. 4, 1985, pp. 469-472.
|
||||
//
|
||||
// This form of ElGamal embeds PKCS#1 v1.5 padding, which may make it
|
||||
// unsuitable for other protocols. RSA should be used in preference in any
|
||||
// case.
|
||||
package elgamal
|
||||
|
||||
import (
|
||||
"big"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// PublicKey represents an ElGamal public key.
|
||||
type PublicKey struct {
|
||||
G, P, Y *big.Int
|
||||
}
|
||||
|
||||
// PrivateKey represents an ElGamal private key.
|
||||
type PrivateKey struct {
|
||||
PublicKey
|
||||
X *big.Int
|
||||
}
|
||||
|
||||
// Encrypt encrypts the given message to the given public key. The result is a
|
||||
// pair of integers. Errors can result from reading random, or because msg is
|
||||
// too large to be encrypted to the public key.
|
||||
func Encrypt(random io.Reader, pub *PublicKey, msg []byte) (c1, c2 *big.Int, err os.Error) {
|
||||
pLen := (pub.P.BitLen() + 7) / 8
|
||||
if len(msg) > pLen-11 {
|
||||
err = os.NewError("elgamal: message too long")
|
||||
return
|
||||
}
|
||||
|
||||
// EM = 0x02 || PS || 0x00 || M
|
||||
em := make([]byte, pLen-1)
|
||||
em[0] = 2
|
||||
ps, mm := em[1:len(em)-len(msg)-1], em[len(em)-len(msg):]
|
||||
err = nonZeroRandomBytes(ps, random)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
em[len(em)-len(msg)-1] = 0
|
||||
copy(mm, msg)
|
||||
|
||||
m := new(big.Int).SetBytes(em)
|
||||
|
||||
k, err := rand.Int(random, pub.P)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c1 = new(big.Int).Exp(pub.G, k, pub.P)
|
||||
s := new(big.Int).Exp(pub.Y, k, pub.P)
|
||||
c2 = s.Mul(s, m)
|
||||
c2.Mod(c2, pub.P)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt takes two integers, resulting from an ElGamal encryption, and
|
||||
// returns the plaintext of the message. An error can result only if the
|
||||
// ciphertext is invalid. Users should keep in mind that this is a padding
|
||||
// oracle and thus, if exposed to an adaptive chosen ciphertext attack, can
|
||||
// be used to break the cryptosystem. See ``Chosen Ciphertext Attacks
|
||||
// Against Protocols Based on the RSA Encryption Standard PKCS #1'', Daniel
|
||||
// Bleichenbacher, Advances in Cryptology (Crypto '98),
|
||||
func Decrypt(priv *PrivateKey, c1, c2 *big.Int) (msg []byte, err os.Error) {
|
||||
s := new(big.Int).Exp(c1, priv.X, priv.P)
|
||||
s.ModInverse(s, priv.P)
|
||||
s.Mul(s, c2)
|
||||
s.Mod(s, priv.P)
|
||||
em := s.Bytes()
|
||||
|
||||
firstByteIsTwo := subtle.ConstantTimeByteEq(em[0], 2)
|
||||
|
||||
// The remainder of the plaintext must be a string of non-zero random
|
||||
// octets, followed by a 0, followed by the message.
|
||||
// lookingForIndex: 1 iff we are still looking for the zero.
|
||||
// index: the offset of the first zero byte.
|
||||
var lookingForIndex, index int
|
||||
lookingForIndex = 1
|
||||
|
||||
for i := 1; i < len(em); i++ {
|
||||
equals0 := subtle.ConstantTimeByteEq(em[i], 0)
|
||||
index = subtle.ConstantTimeSelect(lookingForIndex&equals0, i, index)
|
||||
lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex)
|
||||
}
|
||||
|
||||
if firstByteIsTwo != 1 || lookingForIndex != 0 || index < 9 {
|
||||
return nil, os.NewError("elgamal: decryption error")
|
||||
}
|
||||
return em[index+1:], nil
|
||||
}
|
||||
|
||||
// nonZeroRandomBytes fills the given slice with non-zero random octets.
|
||||
func nonZeroRandomBytes(s []byte, rand io.Reader) (err os.Error) {
|
||||
_, err = io.ReadFull(rand, s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
for s[i] == 0 {
|
||||
_, err = io.ReadFull(rand, s[i:i+1])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
49
libgo/go/crypto/openpgp/elgamal/elgamal_test.go
Normal file
49
libgo/go/crypto/openpgp/elgamal/elgamal_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package elgamal
|
||||
|
||||
import (
|
||||
"big"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// This is the 1024-bit MODP group from RFC 5114, section 2.1:
|
||||
const primeHex = "B10B8F96A080E01DDE92DE5EAE5D54EC52C99FBCFB06A3C69A6A9DCA52D23B616073E28675A23D189838EF1E2EE652C013ECB4AEA906112324975C3CD49B83BFACCBDD7D90C4BD7098488E9C219A73724EFFD6FAE5644738FAA31A4FF55BCCC0A151AF5F0DC8B4BD45BF37DF365C1A65E68CFDA76D4DA708DF1FB2BC2E4A4371"
|
||||
|
||||
const generatorHex = "A4D1CBD5C3FD34126765A442EFB99905F8104DD258AC507FD6406CFF14266D31266FEA1E5C41564B777E690F5504F213160217B4B01B886A5E91547F9E2749F4D7FBD7D3B9A92EE1909D0D2263F80A76A6A24C087A091F531DBF0A0169B6A28AD662A4D18E73AFA32D779D5918D08BC8858F4DCEF97C2A24855E6EEB22B3B2E5"
|
||||
|
||||
func fromHex(hex string) *big.Int {
|
||||
n, ok := new(big.Int).SetString(hex, 16)
|
||||
if !ok {
|
||||
panic("failed to parse hex number")
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
priv := &PrivateKey{
|
||||
PublicKey: PublicKey{
|
||||
G: fromHex(generatorHex),
|
||||
P: fromHex(primeHex),
|
||||
},
|
||||
X: fromHex("42"),
|
||||
}
|
||||
priv.Y = new(big.Int).Exp(priv.G, priv.X, priv.P)
|
||||
|
||||
message := []byte("hello world")
|
||||
c1, c2, err := Encrypt(rand.Reader, &priv.PublicKey, message)
|
||||
if err != nil {
|
||||
t.Errorf("error encrypting: %s", err)
|
||||
}
|
||||
message2, err := Decrypt(priv, c1, c2)
|
||||
if err != nil {
|
||||
t.Errorf("error decrypting: %s", err)
|
||||
}
|
||||
if !bytes.Equal(message2, message) {
|
||||
t.Errorf("decryption failed, got: %x, want: %x", message2, message)
|
||||
}
|
||||
}
|
|
@ -5,11 +5,14 @@
|
|||
package openpgp
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/openpgp/armor"
|
||||
"crypto/openpgp/error"
|
||||
"crypto/openpgp/packet"
|
||||
"crypto/rsa"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PublicKeyType is the armor type for a PGP public key.
|
||||
|
@ -62,6 +65,78 @@ type KeyRing interface {
|
|||
DecryptionKeys() []Key
|
||||
}
|
||||
|
||||
// primaryIdentity returns the Identity marked as primary or the first identity
|
||||
// if none are so marked.
|
||||
func (e *Entity) primaryIdentity() *Identity {
|
||||
var firstIdentity *Identity
|
||||
for _, ident := range e.Identities {
|
||||
if firstIdentity == nil {
|
||||
firstIdentity = ident
|
||||
}
|
||||
if ident.SelfSignature.IsPrimaryId != nil && *ident.SelfSignature.IsPrimaryId {
|
||||
return ident
|
||||
}
|
||||
}
|
||||
return firstIdentity
|
||||
}
|
||||
|
||||
// encryptionKey returns the best candidate Key for encrypting a message to the
|
||||
// given Entity.
|
||||
func (e *Entity) encryptionKey() Key {
|
||||
candidateSubkey := -1
|
||||
|
||||
for i, subkey := range e.Subkeys {
|
||||
if subkey.Sig.FlagsValid && subkey.Sig.FlagEncryptCommunications && subkey.PublicKey.PubKeyAlgo.CanEncrypt() {
|
||||
candidateSubkey = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
i := e.primaryIdentity()
|
||||
|
||||
if e.PrimaryKey.PubKeyAlgo.CanEncrypt() {
|
||||
// If we don't have any candidate subkeys for encryption and
|
||||
// the primary key doesn't have any usage metadata then we
|
||||
// assume that the primary key is ok. Or, if the primary key is
|
||||
// marked as ok to encrypt to, then we can obviously use it.
|
||||
if candidateSubkey == -1 && !i.SelfSignature.FlagsValid || i.SelfSignature.FlagEncryptCommunications && i.SelfSignature.FlagsValid {
|
||||
return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature}
|
||||
}
|
||||
}
|
||||
|
||||
if candidateSubkey != -1 {
|
||||
subkey := e.Subkeys[candidateSubkey]
|
||||
return Key{e, subkey.PublicKey, subkey.PrivateKey, subkey.Sig}
|
||||
}
|
||||
|
||||
// This Entity appears to be signing only.
|
||||
return Key{}
|
||||
}
|
||||
|
||||
// signingKey return the best candidate Key for signing a message with this
|
||||
// Entity.
|
||||
func (e *Entity) signingKey() Key {
|
||||
candidateSubkey := -1
|
||||
|
||||
for i, subkey := range e.Subkeys {
|
||||
if subkey.Sig.FlagsValid && subkey.Sig.FlagSign && subkey.PublicKey.PubKeyAlgo.CanSign() {
|
||||
candidateSubkey = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
i := e.primaryIdentity()
|
||||
|
||||
// If we have no candidate subkey then we assume that it's ok to sign
|
||||
// with the primary key.
|
||||
if candidateSubkey == -1 || i.SelfSignature.FlagsValid && i.SelfSignature.FlagSign {
|
||||
return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature}
|
||||
}
|
||||
|
||||
subkey := e.Subkeys[candidateSubkey]
|
||||
return Key{e, subkey.PublicKey, subkey.PrivateKey, subkey.Sig}
|
||||
}
|
||||
|
||||
// An EntityList contains one or more Entities.
|
||||
type EntityList []*Entity
|
||||
|
||||
|
@ -197,6 +272,10 @@ func readEntity(packets *packet.Reader) (*Entity, os.Error) {
|
|||
}
|
||||
}
|
||||
|
||||
if !e.PrimaryKey.PubKeyAlgo.CanSign() {
|
||||
return nil, error.StructuralError("primary key cannot be used for signatures")
|
||||
}
|
||||
|
||||
var current *Identity
|
||||
EachPacket:
|
||||
for {
|
||||
|
@ -227,7 +306,7 @@ EachPacket:
|
|||
return nil, error.StructuralError("user ID packet not followed by self-signature")
|
||||
}
|
||||
|
||||
if sig.SigType == packet.SigTypePositiveCert && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
|
||||
if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
|
||||
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil {
|
||||
return nil, error.StructuralError("user ID self-signature invalid: " + err.String())
|
||||
}
|
||||
|
@ -297,3 +376,170 @@ func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *p
|
|||
e.Subkeys = append(e.Subkeys, subKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
const defaultRSAKeyBits = 2048
|
||||
|
||||
// NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a
|
||||
// single identity composed of the given full name, comment and email, any of
|
||||
// which may be empty but must not contain any of "()<>\x00".
|
||||
func NewEntity(rand io.Reader, currentTimeSecs int64, name, comment, email string) (*Entity, os.Error) {
|
||||
uid := packet.NewUserId(name, comment, email)
|
||||
if uid == nil {
|
||||
return nil, error.InvalidArgumentError("user id field contained invalid characters")
|
||||
}
|
||||
signingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encryptingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := uint32(currentTimeSecs)
|
||||
|
||||
e := &Entity{
|
||||
PrimaryKey: packet.NewRSAPublicKey(t, &signingPriv.PublicKey, false /* not a subkey */ ),
|
||||
PrivateKey: packet.NewRSAPrivateKey(t, signingPriv, false /* not a subkey */ ),
|
||||
Identities: make(map[string]*Identity),
|
||||
}
|
||||
isPrimaryId := true
|
||||
e.Identities[uid.Id] = &Identity{
|
||||
Name: uid.Name,
|
||||
UserId: uid,
|
||||
SelfSignature: &packet.Signature{
|
||||
CreationTime: t,
|
||||
SigType: packet.SigTypePositiveCert,
|
||||
PubKeyAlgo: packet.PubKeyAlgoRSA,
|
||||
Hash: crypto.SHA256,
|
||||
IsPrimaryId: &isPrimaryId,
|
||||
FlagsValid: true,
|
||||
FlagSign: true,
|
||||
FlagCertify: true,
|
||||
IssuerKeyId: &e.PrimaryKey.KeyId,
|
||||
},
|
||||
}
|
||||
|
||||
e.Subkeys = make([]Subkey, 1)
|
||||
e.Subkeys[0] = Subkey{
|
||||
PublicKey: packet.NewRSAPublicKey(t, &encryptingPriv.PublicKey, true /* is a subkey */ ),
|
||||
PrivateKey: packet.NewRSAPrivateKey(t, encryptingPriv, true /* is a subkey */ ),
|
||||
Sig: &packet.Signature{
|
||||
CreationTime: t,
|
||||
SigType: packet.SigTypeSubkeyBinding,
|
||||
PubKeyAlgo: packet.PubKeyAlgoRSA,
|
||||
Hash: crypto.SHA256,
|
||||
FlagsValid: true,
|
||||
FlagEncryptStorage: true,
|
||||
FlagEncryptCommunications: true,
|
||||
IssuerKeyId: &e.PrimaryKey.KeyId,
|
||||
},
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// SerializePrivate serializes an Entity, including private key material, to
|
||||
// the given Writer. For now, it must only be used on an Entity returned from
|
||||
// NewEntity.
|
||||
func (e *Entity) SerializePrivate(w io.Writer) (err os.Error) {
|
||||
err = e.PrivateKey.Serialize(w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, ident := range e.Identities {
|
||||
err = ident.UserId.Serialize(w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = ident.SelfSignature.SignUserId(ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = ident.SelfSignature.Serialize(w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, subkey := range e.Subkeys {
|
||||
err = subkey.PrivateKey.Serialize(w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = subkey.Sig.SignKey(subkey.PublicKey, e.PrivateKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = subkey.Sig.Serialize(w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serialize writes the public part of the given Entity to w. (No private
|
||||
// key material will be output).
|
||||
func (e *Entity) Serialize(w io.Writer) os.Error {
|
||||
err := e.PrimaryKey.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, ident := range e.Identities {
|
||||
err = ident.UserId.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ident.SelfSignature.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, sig := range ident.Signatures {
|
||||
err = sig.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, subkey := range e.Subkeys {
|
||||
err = subkey.PublicKey.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = subkey.Sig.Serialize(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignIdentity adds a signature to e, from signer, attesting that identity is
|
||||
// associated with e. The provided identity must already be an element of
|
||||
// e.Identities and the private key of signer must have been decrypted if
|
||||
// necessary.
|
||||
func (e *Entity) SignIdentity(identity string, signer *Entity) os.Error {
|
||||
if signer.PrivateKey == nil {
|
||||
return error.InvalidArgumentError("signing Entity must have a private key")
|
||||
}
|
||||
if signer.PrivateKey.Encrypted {
|
||||
return error.InvalidArgumentError("signing Entity's private key must be decrypted")
|
||||
}
|
||||
ident, ok := e.Identities[identity]
|
||||
if !ok {
|
||||
return error.InvalidArgumentError("given identity string not found in Entity")
|
||||
}
|
||||
|
||||
sig := &packet.Signature{
|
||||
SigType: packet.SigTypeGenericCert,
|
||||
PubKeyAlgo: signer.PrivateKey.PubKeyAlgo,
|
||||
Hash: crypto.SHA256,
|
||||
CreationTime: uint32(time.Seconds()),
|
||||
IssuerKeyId: &signer.PrivateKey.KeyId,
|
||||
}
|
||||
if err := sig.SignKey(e.PrimaryKey, signer.PrivateKey); err != nil {
|
||||
return err
|
||||
}
|
||||
ident.Signatures = append(ident.Signatures, sig)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
package packet
|
||||
|
||||
import (
|
||||
"big"
|
||||
"crypto/openpgp/elgamal"
|
||||
"crypto/openpgp/error"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
|
@ -14,14 +16,17 @@ import (
|
|||
"strconv"
|
||||
)
|
||||
|
||||
const encryptedKeyVersion = 3
|
||||
|
||||
// EncryptedKey represents a public-key encrypted session key. See RFC 4880,
|
||||
// section 5.1.
|
||||
type EncryptedKey struct {
|
||||
KeyId uint64
|
||||
Algo PublicKeyAlgorithm
|
||||
Encrypted []byte
|
||||
CipherFunc CipherFunction // only valid after a successful Decrypt
|
||||
Key []byte // only valid after a successful Decrypt
|
||||
|
||||
encryptedMPI1, encryptedMPI2 []byte
|
||||
}
|
||||
|
||||
func (e *EncryptedKey) parse(r io.Reader) (err os.Error) {
|
||||
|
@ -30,37 +35,134 @@ func (e *EncryptedKey) parse(r io.Reader) (err os.Error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
if buf[0] != 3 {
|
||||
if buf[0] != encryptedKeyVersion {
|
||||
return error.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
|
||||
}
|
||||
e.KeyId = binary.BigEndian.Uint64(buf[1:9])
|
||||
e.Algo = PublicKeyAlgorithm(buf[9])
|
||||
if e.Algo == PubKeyAlgoRSA || e.Algo == PubKeyAlgoRSAEncryptOnly {
|
||||
e.Encrypted, _, err = readMPI(r)
|
||||
switch e.Algo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
|
||||
e.encryptedMPI1, _, err = readMPI(r)
|
||||
case PubKeyAlgoElGamal:
|
||||
e.encryptedMPI1, _, err = readMPI(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e.encryptedMPI2, _, err = readMPI(r)
|
||||
}
|
||||
_, err = consumeAll(r)
|
||||
return
|
||||
}
|
||||
|
||||
// DecryptRSA decrypts an RSA encrypted session key with the given private key.
|
||||
func (e *EncryptedKey) DecryptRSA(priv *rsa.PrivateKey) (err os.Error) {
|
||||
if e.Algo != PubKeyAlgoRSA && e.Algo != PubKeyAlgoRSAEncryptOnly {
|
||||
return error.InvalidArgumentError("EncryptedKey not RSA encrypted")
|
||||
func checksumKeyMaterial(key []byte) uint16 {
|
||||
var checksum uint16
|
||||
for _, v := range key {
|
||||
checksum += uint16(v)
|
||||
}
|
||||
b, err := rsa.DecryptPKCS1v15(rand.Reader, priv, e.Encrypted)
|
||||
return checksum
|
||||
}
|
||||
|
||||
// Decrypt decrypts an encrypted session key with the given private key. The
|
||||
// private key must have been decrypted first.
|
||||
func (e *EncryptedKey) Decrypt(priv *PrivateKey) os.Error {
|
||||
var err os.Error
|
||||
var b []byte
|
||||
|
||||
// TODO(agl): use session key decryption routines here to avoid
|
||||
// padding oracle attacks.
|
||||
switch priv.PubKeyAlgo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
|
||||
b, err = rsa.DecryptPKCS1v15(rand.Reader, priv.PrivateKey.(*rsa.PrivateKey), e.encryptedMPI1)
|
||||
case PubKeyAlgoElGamal:
|
||||
c1 := new(big.Int).SetBytes(e.encryptedMPI1)
|
||||
c2 := new(big.Int).SetBytes(e.encryptedMPI2)
|
||||
b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2)
|
||||
default:
|
||||
err = error.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
e.CipherFunc = CipherFunction(b[0])
|
||||
e.Key = b[1 : len(b)-2]
|
||||
expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
|
||||
var checksum uint16
|
||||
for _, v := range e.Key {
|
||||
checksum += uint16(v)
|
||||
}
|
||||
checksum := checksumKeyMaterial(e.Key)
|
||||
if checksum != expectedChecksum {
|
||||
return error.StructuralError("EncryptedKey checksum incorrect")
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// SerializeEncryptedKey serializes an encrypted key packet to w that contains
|
||||
// key, encrypted to pub.
|
||||
func SerializeEncryptedKey(w io.Writer, rand io.Reader, pub *PublicKey, cipherFunc CipherFunction, key []byte) os.Error {
|
||||
var buf [10]byte
|
||||
buf[0] = encryptedKeyVersion
|
||||
binary.BigEndian.PutUint64(buf[1:9], pub.KeyId)
|
||||
buf[9] = byte(pub.PubKeyAlgo)
|
||||
|
||||
keyBlock := make([]byte, 1 /* cipher type */ +len(key)+2 /* checksum */ )
|
||||
keyBlock[0] = byte(cipherFunc)
|
||||
copy(keyBlock[1:], key)
|
||||
checksum := checksumKeyMaterial(key)
|
||||
keyBlock[1+len(key)] = byte(checksum >> 8)
|
||||
keyBlock[1+len(key)+1] = byte(checksum)
|
||||
|
||||
switch pub.PubKeyAlgo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
|
||||
return serializeEncryptedKeyRSA(w, rand, buf, pub.PublicKey.(*rsa.PublicKey), keyBlock)
|
||||
case PubKeyAlgoElGamal:
|
||||
return serializeEncryptedKeyElGamal(w, rand, buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock)
|
||||
case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
|
||||
return error.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
|
||||
}
|
||||
|
||||
return error.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
|
||||
}
|
||||
|
||||
func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) os.Error {
|
||||
cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
|
||||
if err != nil {
|
||||
return error.InvalidArgumentError("RSA encryption failed: " + err.String())
|
||||
}
|
||||
|
||||
packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText)
|
||||
|
||||
err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(header[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeMPI(w, 8*uint16(len(cipherText)), cipherText)
|
||||
}
|
||||
|
||||
func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) os.Error {
|
||||
c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
|
||||
if err != nil {
|
||||
return error.InvalidArgumentError("ElGamal encryption failed: " + err.String())
|
||||
}
|
||||
|
||||
packetLen := 10 /* header length */
|
||||
packetLen += 2 /* mpi size */ + (c1.BitLen()+7)/8
|
||||
packetLen += 2 /* mpi size */ + (c2.BitLen()+7)/8
|
||||
|
||||
err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(header[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = writeBig(w, c1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeBig(w, c2)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ package packet
|
|||
|
||||
import (
|
||||
"big"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
@ -19,7 +21,27 @@ func bigFromBase10(s string) *big.Int {
|
|||
return b
|
||||
}
|
||||
|
||||
func TestEncryptedKey(t *testing.T) {
|
||||
var encryptedKeyPub = rsa.PublicKey{
|
||||
E: 65537,
|
||||
N: bigFromBase10("115804063926007623305902631768113868327816898845124614648849934718568541074358183759250136204762053879858102352159854352727097033322663029387610959884180306668628526686121021235757016368038585212410610742029286439607686208110250133174279811431933746643015923132833417396844716207301518956640020862630546868823"),
|
||||
}
|
||||
|
||||
var encryptedKeyRSAPriv = &rsa.PrivateKey{
|
||||
PublicKey: encryptedKeyPub,
|
||||
D: bigFromBase10("32355588668219869544751561565313228297765464314098552250409557267371233892496951383426602439009993875125222579159850054973310859166139474359774543943714622292329487391199285040721944491839695981199720170366763547754915493640685849961780092241140181198779299712578774460837139360803883139311171713302987058393"),
|
||||
}
|
||||
|
||||
var encryptedKeyPriv = &PrivateKey{
|
||||
PublicKey: PublicKey{
|
||||
PubKeyAlgo: PubKeyAlgoRSA,
|
||||
},
|
||||
PrivateKey: encryptedKeyRSAPriv,
|
||||
}
|
||||
|
||||
func TestDecryptingEncryptedKey(t *testing.T) {
|
||||
const encryptedKeyHex = "c18c032a67d68660df41c70104005789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8"
|
||||
const expectedKeyHex = "d930363f7e0308c333b9618617ea728963d8df993665ae7be1092d4926fd864b"
|
||||
|
||||
p, err := Read(readerFromHex(encryptedKeyHex))
|
||||
if err != nil {
|
||||
t.Errorf("error from Read: %s", err)
|
||||
|
@ -36,19 +58,9 @@ func TestEncryptedKey(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
pub := rsa.PublicKey{
|
||||
E: 65537,
|
||||
N: bigFromBase10("115804063926007623305902631768113868327816898845124614648849934718568541074358183759250136204762053879858102352159854352727097033322663029387610959884180306668628526686121021235757016368038585212410610742029286439607686208110250133174279811431933746643015923132833417396844716207301518956640020862630546868823"),
|
||||
}
|
||||
|
||||
priv := &rsa.PrivateKey{
|
||||
PublicKey: pub,
|
||||
D: bigFromBase10("32355588668219869544751561565313228297765464314098552250409557267371233892496951383426602439009993875125222579159850054973310859166139474359774543943714622292329487391199285040721944491839695981199720170366763547754915493640685849961780092241140181198779299712578774460837139360803883139311171713302987058393"),
|
||||
}
|
||||
|
||||
err = ek.DecryptRSA(priv)
|
||||
err = ek.Decrypt(encryptedKeyPriv)
|
||||
if err != nil {
|
||||
t.Errorf("error from DecryptRSA: %s", err)
|
||||
t.Errorf("error from Decrypt: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -63,5 +75,52 @@ func TestEncryptedKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
const encryptedKeyHex = "c18c032a67d68660df41c70104005789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8"
|
||||
const expectedKeyHex = "d930363f7e0308c333b9618617ea728963d8df993665ae7be1092d4926fd864b"
|
||||
func TestEncryptingEncryptedKey(t *testing.T) {
|
||||
key := []byte{1, 2, 3, 4}
|
||||
const expectedKeyHex = "01020304"
|
||||
const keyId = 42
|
||||
|
||||
pub := &PublicKey{
|
||||
PublicKey: &encryptedKeyPub,
|
||||
KeyId: keyId,
|
||||
PubKeyAlgo: PubKeyAlgoRSAEncryptOnly,
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
err := SerializeEncryptedKey(buf, rand.Reader, pub, CipherAES128, key)
|
||||
if err != nil {
|
||||
t.Errorf("error writing encrypted key packet: %s", err)
|
||||
}
|
||||
|
||||
p, err := Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("error from Read: %s", err)
|
||||
return
|
||||
}
|
||||
ek, ok := p.(*EncryptedKey)
|
||||
if !ok {
|
||||
t.Errorf("didn't parse an EncryptedKey, got %#v", p)
|
||||
return
|
||||
}
|
||||
|
||||
if ek.KeyId != keyId || ek.Algo != PubKeyAlgoRSAEncryptOnly {
|
||||
t.Errorf("unexpected EncryptedKey contents: %#v", ek)
|
||||
return
|
||||
}
|
||||
|
||||
err = ek.Decrypt(encryptedKeyPriv)
|
||||
if err != nil {
|
||||
t.Errorf("error from Decrypt: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if ek.CipherFunc != CipherAES128 {
|
||||
t.Errorf("unexpected EncryptedKey contents: %#v", ek)
|
||||
return
|
||||
}
|
||||
|
||||
keyHex := fmt.Sprintf("%x", ek.Key)
|
||||
if keyHex != expectedKeyHex {
|
||||
t.Errorf("bad key, got %s want %x", keyHex, expectedKeyHex)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,3 +51,40 @@ func (l *LiteralData) parse(r io.Reader) (err os.Error) {
|
|||
l.Body = r
|
||||
return
|
||||
}
|
||||
|
||||
// SerializeLiteral serializes a literal data packet to w and returns a
|
||||
// WriteCloser to which the data itself can be written and which MUST be closed
|
||||
// on completion. The fileName is truncated to 255 bytes.
|
||||
func SerializeLiteral(w io.WriteCloser, isBinary bool, fileName string, time uint32) (plaintext io.WriteCloser, err os.Error) {
|
||||
var buf [4]byte
|
||||
buf[0] = 't'
|
||||
if isBinary {
|
||||
buf[0] = 'b'
|
||||
}
|
||||
if len(fileName) > 255 {
|
||||
fileName = fileName[:255]
|
||||
}
|
||||
buf[1] = byte(len(fileName))
|
||||
|
||||
inner, err := serializeStreamHeader(w, packetTypeLiteralData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = inner.Write(buf[:2])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = inner.Write([]byte(fileName))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
binary.BigEndian.PutUint32(buf[:], time)
|
||||
_, err = inner.Write(buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
plaintext = inner
|
||||
return
|
||||
}
|
||||
|
|
|
@ -24,6 +24,8 @@ type OnePassSignature struct {
|
|||
IsLast bool
|
||||
}
|
||||
|
||||
const onePassSignatureVersion = 3
|
||||
|
||||
func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
|
||||
var buf [13]byte
|
||||
|
||||
|
@ -31,7 +33,7 @@ func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
if buf[0] != 3 {
|
||||
if buf[0] != onePassSignatureVersion {
|
||||
err = error.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
|
||||
}
|
||||
|
||||
|
@ -47,3 +49,26 @@ func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
|
|||
ops.IsLast = buf[12] != 0
|
||||
return
|
||||
}
|
||||
|
||||
// Serialize marshals the given OnePassSignature to w.
|
||||
func (ops *OnePassSignature) Serialize(w io.Writer) os.Error {
|
||||
var buf [13]byte
|
||||
buf[0] = onePassSignatureVersion
|
||||
buf[1] = uint8(ops.SigType)
|
||||
var ok bool
|
||||
buf[2], ok = s2k.HashToHashId(ops.Hash)
|
||||
if !ok {
|
||||
return error.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
|
||||
}
|
||||
buf[3] = uint8(ops.PubKeyAlgo)
|
||||
binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)
|
||||
if ops.IsLast {
|
||||
buf[12] = 1
|
||||
}
|
||||
|
||||
if err := serializeHeader(w, packetTypeOnePassSignature, len(buf)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.Write(buf[:])
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package packet implements parsing and serialisation of OpenPGP packets, as
|
||||
// Package packet implements parsing and serialization of OpenPGP packets, as
|
||||
// specified in RFC 4880.
|
||||
package packet
|
||||
|
||||
|
@ -92,6 +92,46 @@ func (r *partialLengthReader) Read(p []byte) (n int, err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
// partialLengthWriter writes a stream of data using OpenPGP partial lengths.
|
||||
// See RFC 4880, section 4.2.2.4.
|
||||
type partialLengthWriter struct {
|
||||
w io.WriteCloser
|
||||
lengthByte [1]byte
|
||||
}
|
||||
|
||||
func (w *partialLengthWriter) Write(p []byte) (n int, err os.Error) {
|
||||
for len(p) > 0 {
|
||||
for power := uint(14); power < 32; power-- {
|
||||
l := 1 << power
|
||||
if len(p) >= l {
|
||||
w.lengthByte[0] = 224 + uint8(power)
|
||||
_, err = w.w.Write(w.lengthByte[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var m int
|
||||
m, err = w.w.Write(p[:l])
|
||||
n += m
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p = p[l:]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (w *partialLengthWriter) Close() os.Error {
|
||||
w.lengthByte[0] = 0
|
||||
_, err := w.w.Write(w.lengthByte[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return w.w.Close()
|
||||
}
|
||||
|
||||
// A spanReader is an io.LimitReader, but it returns ErrUnexpectedEOF if the
|
||||
// underlying Reader returns EOF before the limit has been reached.
|
||||
type spanReader struct {
|
||||
|
@ -195,6 +235,20 @@ func serializeHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
// serializeStreamHeader writes an OpenPGP packet header to w where the
|
||||
// length of the packet is unknown. It returns a io.WriteCloser which can be
|
||||
// used to write the contents of the packet. See RFC 4880, section 4.2.
|
||||
func serializeStreamHeader(w io.WriteCloser, ptype packetType) (out io.WriteCloser, err os.Error) {
|
||||
var buf [1]byte
|
||||
buf[0] = 0x80 | 0x40 | byte(ptype)
|
||||
_, err = w.Write(buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
out = &partialLengthWriter{w: w}
|
||||
return
|
||||
}
|
||||
|
||||
// Packet represents an OpenPGP packet. Users are expected to try casting
|
||||
// instances of this interface to specific packet types.
|
||||
type Packet interface {
|
||||
|
@ -301,12 +355,12 @@ type SignatureType uint8
|
|||
|
||||
const (
|
||||
SigTypeBinary SignatureType = 0
|
||||
SigTypeText = 1
|
||||
SigTypeGenericCert = 0x10
|
||||
SigTypePersonaCert = 0x11
|
||||
SigTypeCasualCert = 0x12
|
||||
SigTypePositiveCert = 0x13
|
||||
SigTypeSubkeyBinding = 0x18
|
||||
SigTypeText = 1
|
||||
SigTypeGenericCert = 0x10
|
||||
SigTypePersonaCert = 0x11
|
||||
SigTypeCasualCert = 0x12
|
||||
SigTypePositiveCert = 0x13
|
||||
SigTypeSubkeyBinding = 0x18
|
||||
)
|
||||
|
||||
// PublicKeyAlgorithm represents the different public key system specified for
|
||||
|
@ -318,23 +372,43 @@ const (
|
|||
PubKeyAlgoRSA PublicKeyAlgorithm = 1
|
||||
PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2
|
||||
PubKeyAlgoRSASignOnly PublicKeyAlgorithm = 3
|
||||
PubKeyAlgoElgamal PublicKeyAlgorithm = 16
|
||||
PubKeyAlgoElGamal PublicKeyAlgorithm = 16
|
||||
PubKeyAlgoDSA PublicKeyAlgorithm = 17
|
||||
)
|
||||
|
||||
// CanEncrypt returns true if it's possible to encrypt a message to a public
|
||||
// key of the given type.
|
||||
func (pka PublicKeyAlgorithm) CanEncrypt() bool {
|
||||
switch pka {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CanSign returns true if it's possible for a public key of the given type to
|
||||
// sign a message.
|
||||
func (pka PublicKeyAlgorithm) CanSign() bool {
|
||||
switch pka {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CipherFunction represents the different block ciphers specified for OpenPGP. See
|
||||
// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-13
|
||||
type CipherFunction uint8
|
||||
|
||||
const (
|
||||
CipherCAST5 = 3
|
||||
CipherAES128 = 7
|
||||
CipherAES192 = 8
|
||||
CipherAES256 = 9
|
||||
CipherCAST5 CipherFunction = 3
|
||||
CipherAES128 CipherFunction = 7
|
||||
CipherAES192 CipherFunction = 8
|
||||
CipherAES256 CipherFunction = 9
|
||||
)
|
||||
|
||||
// keySize returns the key size, in bytes, of cipher.
|
||||
func (cipher CipherFunction) keySize() int {
|
||||
// KeySize returns the key size, in bytes, of cipher.
|
||||
func (cipher CipherFunction) KeySize() int {
|
||||
switch cipher {
|
||||
case CipherCAST5:
|
||||
return cast5.KeySize
|
||||
|
@ -386,6 +460,14 @@ func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
// mpiLength returns the length of the given *big.Int when serialized as an
|
||||
// MPI.
|
||||
func mpiLength(n *big.Int) (mpiLengthInBytes int) {
|
||||
mpiLengthInBytes = 2 /* MPI length */
|
||||
mpiLengthInBytes += (n.BitLen() + 7) / 8
|
||||
return
|
||||
}
|
||||
|
||||
// writeMPI serializes a big integer to w.
|
||||
func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
|
||||
_, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)})
|
||||
|
|
|
@ -210,3 +210,47 @@ func TestSerializeHeader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartialLengths(t *testing.T) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
w := new(partialLengthWriter)
|
||||
w.w = noOpCloser{buf}
|
||||
|
||||
const maxChunkSize = 64
|
||||
|
||||
var b [maxChunkSize]byte
|
||||
var n uint8
|
||||
for l := 1; l <= maxChunkSize; l++ {
|
||||
for i := 0; i < l; i++ {
|
||||
b[i] = n
|
||||
n++
|
||||
}
|
||||
m, err := w.Write(b[:l])
|
||||
if m != l {
|
||||
t.Errorf("short write got: %d want: %d", m, l)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("error from write: %s", err)
|
||||
}
|
||||
}
|
||||
w.Close()
|
||||
|
||||
want := (maxChunkSize * (maxChunkSize + 1)) / 2
|
||||
copyBuf := bytes.NewBuffer(nil)
|
||||
r := &partialLengthReader{buf, 0, true}
|
||||
m, err := io.Copy(copyBuf, r)
|
||||
if m != int64(want) {
|
||||
t.Errorf("short copy got: %d want: %d", m, want)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("error from copy: %s", err)
|
||||
}
|
||||
|
||||
copyBytes := copyBuf.Bytes()
|
||||
for i := 0; i < want; i++ {
|
||||
if copyBytes[i] != uint8(i) {
|
||||
t.Errorf("bad pattern in copy at %d", i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"bytes"
|
||||
"crypto/cipher"
|
||||
"crypto/dsa"
|
||||
"crypto/openpgp/elgamal"
|
||||
"crypto/openpgp/error"
|
||||
"crypto/openpgp/s2k"
|
||||
"crypto/rsa"
|
||||
|
@ -32,6 +33,13 @@ type PrivateKey struct {
|
|||
iv []byte
|
||||
}
|
||||
|
||||
func NewRSAPrivateKey(currentTimeSecs uint32, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey {
|
||||
pk := new(PrivateKey)
|
||||
pk.PublicKey = *NewRSAPublicKey(currentTimeSecs, &priv.PublicKey, isSubkey)
|
||||
pk.PrivateKey = priv
|
||||
return pk
|
||||
}
|
||||
|
||||
func (pk *PrivateKey) parse(r io.Reader) (err os.Error) {
|
||||
err = (&pk.PublicKey).parse(r)
|
||||
if err != nil {
|
||||
|
@ -91,13 +99,90 @@ func (pk *PrivateKey) parse(r io.Reader) (err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
func mod64kHash(d []byte) uint16 {
|
||||
h := uint16(0)
|
||||
for i := 0; i < len(d); i += 2 {
|
||||
v := uint16(d[i]) << 8
|
||||
if i+1 < len(d) {
|
||||
v += uint16(d[i+1])
|
||||
}
|
||||
h += v
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (pk *PrivateKey) Serialize(w io.Writer) (err os.Error) {
|
||||
// TODO(agl): support encrypted private keys
|
||||
buf := bytes.NewBuffer(nil)
|
||||
err = pk.PublicKey.serializeWithoutHeaders(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buf.WriteByte(0 /* no encryption */ )
|
||||
|
||||
privateKeyBuf := bytes.NewBuffer(nil)
|
||||
|
||||
switch priv := pk.PrivateKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
err = serializeRSAPrivateKey(privateKeyBuf, priv)
|
||||
default:
|
||||
err = error.InvalidArgumentError("non-RSA private key")
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ptype := packetTypePrivateKey
|
||||
contents := buf.Bytes()
|
||||
privateKeyBytes := privateKeyBuf.Bytes()
|
||||
if pk.IsSubkey {
|
||||
ptype = packetTypePrivateSubkey
|
||||
}
|
||||
err = serializeHeader(w, ptype, len(contents)+len(privateKeyBytes)+2)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = w.Write(contents)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = w.Write(privateKeyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
checksum := mod64kHash(privateKeyBytes)
|
||||
var checksumBytes [2]byte
|
||||
checksumBytes[0] = byte(checksum >> 8)
|
||||
checksumBytes[1] = byte(checksum)
|
||||
_, err = w.Write(checksumBytes[:])
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) os.Error {
|
||||
err := writeBig(w, priv.D)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = writeBig(w, priv.Primes[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = writeBig(w, priv.Primes[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeBig(w, priv.Precomputed.Qinv)
|
||||
}
|
||||
|
||||
// Decrypt decrypts an encrypted private key using a passphrase.
|
||||
func (pk *PrivateKey) Decrypt(passphrase []byte) os.Error {
|
||||
if !pk.Encrypted {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := make([]byte, pk.cipher.keySize())
|
||||
key := make([]byte, pk.cipher.KeySize())
|
||||
pk.s2k(key, passphrase)
|
||||
block := pk.cipher.new(key)
|
||||
cfb := cipher.NewCFBDecrypter(block, pk.iv)
|
||||
|
@ -140,6 +225,8 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
|
|||
return pk.parseRSAPrivateKey(data)
|
||||
case PubKeyAlgoDSA:
|
||||
return pk.parseDSAPrivateKey(data)
|
||||
case PubKeyAlgoElGamal:
|
||||
return pk.parseElGamalPrivateKey(data)
|
||||
}
|
||||
panic("impossible")
|
||||
}
|
||||
|
@ -193,3 +280,22 @@ func (pk *PrivateKey) parseDSAPrivateKey(data []byte) (err os.Error) {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pk *PrivateKey) parseElGamalPrivateKey(data []byte) (err os.Error) {
|
||||
pub := pk.PublicKey.PublicKey.(*elgamal.PublicKey)
|
||||
priv := new(elgamal.PrivateKey)
|
||||
priv.PublicKey = *pub
|
||||
|
||||
buf := bytes.NewBuffer(data)
|
||||
x, _, err := readMPI(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
priv.X = new(big.Int).SetBytes(x)
|
||||
pk.PrivateKey = priv
|
||||
pk.Encrypted = false
|
||||
pk.encryptedData = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -8,30 +8,50 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
var privateKeyTests = []struct {
|
||||
privateKeyHex string
|
||||
creationTime uint32
|
||||
}{
|
||||
{
|
||||
privKeyRSAHex,
|
||||
0x4cc349a8,
|
||||
},
|
||||
{
|
||||
privKeyElGamalHex,
|
||||
0x4df9ee1a,
|
||||
},
|
||||
}
|
||||
|
||||
func TestPrivateKeyRead(t *testing.T) {
|
||||
packet, err := Read(readerFromHex(privKeyHex))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
for i, test := range privateKeyTests {
|
||||
packet, err := Read(readerFromHex(test.privateKeyHex))
|
||||
if err != nil {
|
||||
t.Errorf("#%d: failed to parse: %s", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
privKey := packet.(*PrivateKey)
|
||||
privKey := packet.(*PrivateKey)
|
||||
|
||||
if !privKey.Encrypted {
|
||||
t.Error("private key isn't encrypted")
|
||||
return
|
||||
}
|
||||
if !privKey.Encrypted {
|
||||
t.Errorf("#%d: private key isn't encrypted", i)
|
||||
continue
|
||||
}
|
||||
|
||||
err = privKey.Decrypt([]byte("testing"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
err = privKey.Decrypt([]byte("testing"))
|
||||
if err != nil {
|
||||
t.Errorf("#%d: failed to decrypt: %s", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if privKey.CreationTime != 0x4cc349a8 || privKey.Encrypted {
|
||||
t.Errorf("failed to parse, got: %#v", privKey)
|
||||
if privKey.CreationTime != test.creationTime || privKey.Encrypted {
|
||||
t.Errorf("#%d: bad result, got: %#v", i, privKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generated with `gpg --export-secret-keys "Test Key 2"`
|
||||
const privKeyHex = "9501fe044cc349a8010400b70ca0010e98c090008d45d1ee8f9113bd5861fd57b88bacb7c68658747663f1e1a3b5a98f32fda6472373c024b97359cd2efc88ff60f77751adfbf6af5e615e6a1408cfad8bf0cea30b0d5f53aa27ad59089ba9b15b7ebc2777a25d7b436144027e3bcd203909f147d0e332b240cf63d3395f5dfe0df0a6c04e8655af7eacdf0011010001fe0303024a252e7d475fd445607de39a265472aa74a9320ba2dac395faa687e9e0336aeb7e9a7397e511b5afd9dc84557c80ac0f3d4d7bfec5ae16f20d41c8c84a04552a33870b930420e230e179564f6d19bb153145e76c33ae993886c388832b0fa042ddda7f133924f3854481533e0ede31d51278c0519b29abc3bf53da673e13e3e1214b52413d179d7f66deee35cac8eacb060f78379d70ef4af8607e68131ff529439668fc39c9ce6dfef8a5ac234d234802cbfb749a26107db26406213ae5c06d4673253a3cbee1fcbae58d6ab77e38d6e2c0e7c6317c48e054edadb5a40d0d48acb44643d998139a8a66bb820be1f3f80185bc777d14b5954b60effe2448a036d565c6bc0b915fcea518acdd20ab07bc1529f561c58cd044f723109b93f6fd99f876ff891d64306b5d08f48bab59f38695e9109c4dec34013ba3153488ce070268381ba923ee1eb77125b36afcb4347ec3478c8f2735b06ef17351d872e577fa95d0c397c88c71b59629a36aec"
|
||||
const privKeyRSAHex = "9501fe044cc349a8010400b70ca0010e98c090008d45d1ee8f9113bd5861fd57b88bacb7c68658747663f1e1a3b5a98f32fda6472373c024b97359cd2efc88ff60f77751adfbf6af5e615e6a1408cfad8bf0cea30b0d5f53aa27ad59089ba9b15b7ebc2777a25d7b436144027e3bcd203909f147d0e332b240cf63d3395f5dfe0df0a6c04e8655af7eacdf0011010001fe0303024a252e7d475fd445607de39a265472aa74a9320ba2dac395faa687e9e0336aeb7e9a7397e511b5afd9dc84557c80ac0f3d4d7bfec5ae16f20d41c8c84a04552a33870b930420e230e179564f6d19bb153145e76c33ae993886c388832b0fa042ddda7f133924f3854481533e0ede31d51278c0519b29abc3bf53da673e13e3e1214b52413d179d7f66deee35cac8eacb060f78379d70ef4af8607e68131ff529439668fc39c9ce6dfef8a5ac234d234802cbfb749a26107db26406213ae5c06d4673253a3cbee1fcbae58d6ab77e38d6e2c0e7c6317c48e054edadb5a40d0d48acb44643d998139a8a66bb820be1f3f80185bc777d14b5954b60effe2448a036d565c6bc0b915fcea518acdd20ab07bc1529f561c58cd044f723109b93f6fd99f876ff891d64306b5d08f48bab59f38695e9109c4dec34013ba3153488ce070268381ba923ee1eb77125b36afcb4347ec3478c8f2735b06ef17351d872e577fa95d0c397c88c71b59629a36aec"
|
||||
|
||||
// Generated by `gpg --export-secret-keys` followed by a manual extraction of
|
||||
// the ElGamal subkey from the packets.
|
||||
const privKeyElGamalHex = "9d0157044df9ee1a100400eb8e136a58ec39b582629cdadf830bc64e0a94ed8103ca8bb247b27b11b46d1d25297ef4bcc3071785ba0c0bedfe89eabc5287fcc0edf81ab5896c1c8e4b20d27d79813c7aede75320b33eaeeaa586edc00fd1036c10133e6ba0ff277245d0d59d04b2b3421b7244aca5f4a8d870c6f1c1fbff9e1c26699a860b9504f35ca1d700030503fd1ededd3b840795be6d9ccbe3c51ee42e2f39233c432b831ddd9c4e72b7025a819317e47bf94f9ee316d7273b05d5fcf2999c3a681f519b1234bbfa6d359b4752bd9c3f77d6b6456cde152464763414ca130f4e91d91041432f90620fec0e6d6b5116076c2985d5aeaae13be492b9b329efcaf7ee25120159a0a30cd976b42d7afe030302dae7eb80db744d4960c4df930d57e87fe81412eaace9f900e6c839817a614ddb75ba6603b9417c33ea7b6c93967dfa2bcff3fa3c74a5ce2c962db65b03aece14c96cbd0038fc"
|
||||
|
|
|
@ -7,6 +7,7 @@ package packet
|
|||
import (
|
||||
"big"
|
||||
"crypto/dsa"
|
||||
"crypto/openpgp/elgamal"
|
||||
"crypto/openpgp/error"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
|
@ -30,6 +31,28 @@ type PublicKey struct {
|
|||
n, e, p, q, g, y parsedMPI
|
||||
}
|
||||
|
||||
func fromBig(n *big.Int) parsedMPI {
|
||||
return parsedMPI{
|
||||
bytes: n.Bytes(),
|
||||
bitLength: uint16(n.BitLen()),
|
||||
}
|
||||
}
|
||||
|
||||
// NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
|
||||
func NewRSAPublicKey(creationTimeSecs uint32, pub *rsa.PublicKey, isSubkey bool) *PublicKey {
|
||||
pk := &PublicKey{
|
||||
CreationTime: creationTimeSecs,
|
||||
PubKeyAlgo: PubKeyAlgoRSA,
|
||||
PublicKey: pub,
|
||||
IsSubkey: isSubkey,
|
||||
n: fromBig(pub.N),
|
||||
e: fromBig(big.NewInt(int64(pub.E))),
|
||||
}
|
||||
|
||||
pk.setFingerPrintAndKeyId()
|
||||
return pk
|
||||
}
|
||||
|
||||
func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
|
||||
// RFC 4880, section 5.5.2
|
||||
var buf [6]byte
|
||||
|
@ -47,6 +70,8 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
|
|||
err = pk.parseRSA(r)
|
||||
case PubKeyAlgoDSA:
|
||||
err = pk.parseDSA(r)
|
||||
case PubKeyAlgoElGamal:
|
||||
err = pk.parseElGamal(r)
|
||||
default:
|
||||
err = error.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
|
||||
}
|
||||
|
@ -54,14 +79,17 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
pk.setFingerPrintAndKeyId()
|
||||
return
|
||||
}
|
||||
|
||||
func (pk *PublicKey) setFingerPrintAndKeyId() {
|
||||
// RFC 4880, section 12.2
|
||||
fingerPrint := sha1.New()
|
||||
pk.SerializeSignaturePrefix(fingerPrint)
|
||||
pk.Serialize(fingerPrint)
|
||||
pk.serializeWithoutHeaders(fingerPrint)
|
||||
copy(pk.Fingerprint[:], fingerPrint.Sum())
|
||||
pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20])
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// parseRSA parses RSA public key material from the given Reader. See RFC 4880,
|
||||
|
@ -92,7 +120,7 @@ func (pk *PublicKey) parseRSA(r io.Reader) (err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
// parseRSA parses DSA public key material from the given Reader. See RFC 4880,
|
||||
// parseDSA parses DSA public key material from the given Reader. See RFC 4880,
|
||||
// section 5.5.2.
|
||||
func (pk *PublicKey) parseDSA(r io.Reader) (err os.Error) {
|
||||
pk.p.bytes, pk.p.bitLength, err = readMPI(r)
|
||||
|
@ -121,6 +149,30 @@ func (pk *PublicKey) parseDSA(r io.Reader) (err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
// parseElGamal parses ElGamal public key material from the given Reader. See
|
||||
// RFC 4880, section 5.5.2.
|
||||
func (pk *PublicKey) parseElGamal(r io.Reader) (err os.Error) {
|
||||
pk.p.bytes, pk.p.bitLength, err = readMPI(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pk.g.bytes, pk.g.bitLength, err = readMPI(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pk.y.bytes, pk.y.bitLength, err = readMPI(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
elgamal := new(elgamal.PublicKey)
|
||||
elgamal.P = new(big.Int).SetBytes(pk.p.bytes)
|
||||
elgamal.G = new(big.Int).SetBytes(pk.g.bytes)
|
||||
elgamal.Y = new(big.Int).SetBytes(pk.y.bytes)
|
||||
pk.PublicKey = elgamal
|
||||
return
|
||||
}
|
||||
|
||||
// SerializeSignaturePrefix writes the prefix for this public key to the given Writer.
|
||||
// The prefix is used when calculating a signature over this public key. See
|
||||
// RFC 4880, section 5.2.4.
|
||||
|
@ -135,6 +187,10 @@ func (pk *PublicKey) SerializeSignaturePrefix(h hash.Hash) {
|
|||
pLength += 2 + uint16(len(pk.q.bytes))
|
||||
pLength += 2 + uint16(len(pk.g.bytes))
|
||||
pLength += 2 + uint16(len(pk.y.bytes))
|
||||
case PubKeyAlgoElGamal:
|
||||
pLength += 2 + uint16(len(pk.p.bytes))
|
||||
pLength += 2 + uint16(len(pk.g.bytes))
|
||||
pLength += 2 + uint16(len(pk.y.bytes))
|
||||
default:
|
||||
panic("unknown public key algorithm")
|
||||
}
|
||||
|
@ -143,9 +199,40 @@ func (pk *PublicKey) SerializeSignaturePrefix(h hash.Hash) {
|
|||
return
|
||||
}
|
||||
|
||||
// Serialize marshals the PublicKey to w in the form of an OpenPGP public key
|
||||
// packet, not including the packet header.
|
||||
func (pk *PublicKey) Serialize(w io.Writer) (err os.Error) {
|
||||
length := 6 // 6 byte header
|
||||
|
||||
switch pk.PubKeyAlgo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
|
||||
length += 2 + len(pk.n.bytes)
|
||||
length += 2 + len(pk.e.bytes)
|
||||
case PubKeyAlgoDSA:
|
||||
length += 2 + len(pk.p.bytes)
|
||||
length += 2 + len(pk.q.bytes)
|
||||
length += 2 + len(pk.g.bytes)
|
||||
length += 2 + len(pk.y.bytes)
|
||||
case PubKeyAlgoElGamal:
|
||||
length += 2 + len(pk.p.bytes)
|
||||
length += 2 + len(pk.g.bytes)
|
||||
length += 2 + len(pk.y.bytes)
|
||||
default:
|
||||
panic("unknown public key algorithm")
|
||||
}
|
||||
|
||||
packetType := packetTypePublicKey
|
||||
if pk.IsSubkey {
|
||||
packetType = packetTypePublicSubkey
|
||||
}
|
||||
err = serializeHeader(w, packetType, length)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return pk.serializeWithoutHeaders(w)
|
||||
}
|
||||
|
||||
// serializeWithoutHeaders marshals the PublicKey to w in the form of an
|
||||
// OpenPGP public key packet, not including the packet header.
|
||||
func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err os.Error) {
|
||||
var buf [6]byte
|
||||
buf[0] = 4
|
||||
buf[1] = byte(pk.CreationTime >> 24)
|
||||
|
@ -164,13 +251,15 @@ func (pk *PublicKey) Serialize(w io.Writer) (err os.Error) {
|
|||
return writeMPIs(w, pk.n, pk.e)
|
||||
case PubKeyAlgoDSA:
|
||||
return writeMPIs(w, pk.p, pk.q, pk.g, pk.y)
|
||||
case PubKeyAlgoElGamal:
|
||||
return writeMPIs(w, pk.p, pk.g, pk.y)
|
||||
}
|
||||
return error.InvalidArgumentError("bad public-key algorithm")
|
||||
}
|
||||
|
||||
// CanSign returns true iff this public key can generate signatures
|
||||
func (pk *PublicKey) CanSign() bool {
|
||||
return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElgamal
|
||||
return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElGamal
|
||||
}
|
||||
|
||||
// VerifySignature returns nil iff sig is a valid signature, made by this
|
||||
|
@ -194,14 +283,14 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
|
|||
switch pk.PubKeyAlgo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
|
||||
rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
|
||||
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature)
|
||||
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes)
|
||||
if err != nil {
|
||||
return error.SignatureError("RSA verification failure")
|
||||
}
|
||||
return nil
|
||||
case PubKeyAlgoDSA:
|
||||
dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
|
||||
if !dsa.Verify(dsaPublicKey, hashBytes, sig.DSASigR, sig.DSASigS) {
|
||||
if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
|
||||
return error.SignatureError("DSA verification failure")
|
||||
}
|
||||
return nil
|
||||
|
@ -211,34 +300,43 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
|
|||
panic("unreachable")
|
||||
}
|
||||
|
||||
// VerifyKeySignature returns nil iff sig is a valid signature, make by this
|
||||
// public key, of the public key in signed.
|
||||
func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err os.Error) {
|
||||
h := sig.Hash.New()
|
||||
// keySignatureHash returns a Hash of the message that needs to be signed for
|
||||
// pk to assert a subkey relationship to signed.
|
||||
func keySignatureHash(pk, signed *PublicKey, sig *Signature) (h hash.Hash, err os.Error) {
|
||||
h = sig.Hash.New()
|
||||
if h == nil {
|
||||
return error.UnsupportedError("hash function")
|
||||
return nil, error.UnsupportedError("hash function")
|
||||
}
|
||||
|
||||
// RFC 4880, section 5.2.4
|
||||
pk.SerializeSignaturePrefix(h)
|
||||
pk.Serialize(h)
|
||||
pk.serializeWithoutHeaders(h)
|
||||
signed.SerializeSignaturePrefix(h)
|
||||
signed.Serialize(h)
|
||||
signed.serializeWithoutHeaders(h)
|
||||
return
|
||||
}
|
||||
|
||||
// VerifyKeySignature returns nil iff sig is a valid signature, made by this
|
||||
// public key, of signed.
|
||||
func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err os.Error) {
|
||||
h, err := keySignatureHash(pk, signed, sig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return pk.VerifySignature(h, sig)
|
||||
}
|
||||
|
||||
// VerifyUserIdSignature returns nil iff sig is a valid signature, make by this
|
||||
// public key, of the given user id.
|
||||
func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Error) {
|
||||
h := sig.Hash.New()
|
||||
// userIdSignatureHash returns a Hash of the message that needs to be signed
|
||||
// to assert that pk is a valid key for id.
|
||||
func userIdSignatureHash(id string, pk *PublicKey, sig *Signature) (h hash.Hash, err os.Error) {
|
||||
h = sig.Hash.New()
|
||||
if h == nil {
|
||||
return error.UnsupportedError("hash function")
|
||||
return nil, error.UnsupportedError("hash function")
|
||||
}
|
||||
|
||||
// RFC 4880, section 5.2.4
|
||||
pk.SerializeSignaturePrefix(h)
|
||||
pk.Serialize(h)
|
||||
pk.serializeWithoutHeaders(h)
|
||||
|
||||
var buf [5]byte
|
||||
buf[0] = 0xb4
|
||||
|
@ -249,6 +347,16 @@ func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Er
|
|||
h.Write(buf[:])
|
||||
h.Write([]byte(id))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// VerifyUserIdSignature returns nil iff sig is a valid signature, made by this
|
||||
// public key, of id.
|
||||
func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Error) {
|
||||
h, err := userIdSignatureHash(id, pk, sig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return pk.VerifySignature(h, sig)
|
||||
}
|
||||
|
||||
|
@ -272,7 +380,7 @@ type parsedMPI struct {
|
|||
bitLength uint16
|
||||
}
|
||||
|
||||
// writeMPIs is a utility function for serialising several big integers to the
|
||||
// writeMPIs is a utility function for serializing several big integers to the
|
||||
// given Writer.
|
||||
func writeMPIs(w io.Writer, mpis ...parsedMPI) (err os.Error) {
|
||||
for _, mpi := range mpis {
|
||||
|
|
|
@ -28,12 +28,12 @@ func TestPublicKeyRead(t *testing.T) {
|
|||
packet, err := Read(readerFromHex(test.hexData))
|
||||
if err != nil {
|
||||
t.Errorf("#%d: Read error: %s", i, err)
|
||||
return
|
||||
continue
|
||||
}
|
||||
pk, ok := packet.(*PublicKey)
|
||||
if !ok {
|
||||
t.Errorf("#%d: failed to parse, got: %#v", i, packet)
|
||||
return
|
||||
continue
|
||||
}
|
||||
if pk.PubKeyAlgo != test.pubKeyAlgo {
|
||||
t.Errorf("#%d: bad public key algorithm got:%x want:%x", i, pk.PubKeyAlgo, test.pubKeyAlgo)
|
||||
|
@ -57,6 +57,38 @@ func TestPublicKeyRead(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPublicKeySerialize(t *testing.T) {
|
||||
for i, test := range pubKeyTests {
|
||||
packet, err := Read(readerFromHex(test.hexData))
|
||||
if err != nil {
|
||||
t.Errorf("#%d: Read error: %s", i, err)
|
||||
continue
|
||||
}
|
||||
pk, ok := packet.(*PublicKey)
|
||||
if !ok {
|
||||
t.Errorf("#%d: failed to parse, got: %#v", i, packet)
|
||||
continue
|
||||
}
|
||||
serializeBuf := bytes.NewBuffer(nil)
|
||||
err = pk.Serialize(serializeBuf)
|
||||
if err != nil {
|
||||
t.Errorf("#%d: failed to serialize: %s", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
packet, err = Read(serializeBuf)
|
||||
if err != nil {
|
||||
t.Errorf("#%d: Read error (from serialized data): %s", i, err)
|
||||
continue
|
||||
}
|
||||
pk, ok = packet.(*PublicKey)
|
||||
if !ok {
|
||||
t.Errorf("#%d: failed to parse serialized data, got: %#v", i, packet)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const rsaFingerprintHex = "5fb74b1d03b1e3cb31bc2f8aa34d7e18c20c31bb"
|
||||
|
||||
const rsaPkDataHex = "988d044d3c5c10010400b1d13382944bd5aba23a4312968b5095d14f947f600eb478e14a6fcb16b0e0cac764884909c020bc495cfcc39a935387c661507bdb236a0612fb582cac3af9b29cc2c8c70090616c41b662f4da4c1201e195472eb7f4ae1ccbcbf9940fe21d985e379a5563dde5b9a23d35f1cfaa5790da3b79db26f23695107bfaca8e7b5bcd0011010001"
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
package packet
|
||||
|
||||
import (
|
||||
"big"
|
||||
"crypto"
|
||||
"crypto/dsa"
|
||||
"crypto/openpgp/error"
|
||||
|
@ -32,8 +31,11 @@ type Signature struct {
|
|||
HashTag [2]byte
|
||||
CreationTime uint32 // Unix epoch time
|
||||
|
||||
RSASignature []byte
|
||||
DSASigR, DSASigS *big.Int
|
||||
RSASignature parsedMPI
|
||||
DSASigR, DSASigS parsedMPI
|
||||
|
||||
// rawSubpackets contains the unparsed subpackets, in order.
|
||||
rawSubpackets []outputSubpacket
|
||||
|
||||
// The following are optional so are nil when not included in the
|
||||
// signature.
|
||||
|
@ -128,14 +130,11 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
|
|||
|
||||
switch sig.PubKeyAlgo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
|
||||
sig.RSASignature, _, err = readMPI(r)
|
||||
sig.RSASignature.bytes, sig.RSASignature.bitLength, err = readMPI(r)
|
||||
case PubKeyAlgoDSA:
|
||||
var rBytes, sBytes []byte
|
||||
rBytes, _, err = readMPI(r)
|
||||
sig.DSASigR = new(big.Int).SetBytes(rBytes)
|
||||
sig.DSASigR.bytes, sig.DSASigR.bitLength, err = readMPI(r)
|
||||
if err == nil {
|
||||
sBytes, _, err = readMPI(r)
|
||||
sig.DSASigS = new(big.Int).SetBytes(sBytes)
|
||||
sig.DSASigS.bytes, sig.DSASigS.bitLength, err = readMPI(r)
|
||||
}
|
||||
default:
|
||||
panic("unreachable")
|
||||
|
@ -177,7 +176,11 @@ const (
|
|||
// parseSignatureSubpacket parses a single subpacket. len(subpacket) is >= 1.
|
||||
func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (rest []byte, err os.Error) {
|
||||
// RFC 4880, section 5.2.3.1
|
||||
var length uint32
|
||||
var (
|
||||
length uint32
|
||||
packetType signatureSubpacketType
|
||||
isCritical bool
|
||||
)
|
||||
switch {
|
||||
case subpacket[0] < 192:
|
||||
length = uint32(subpacket[0])
|
||||
|
@ -207,10 +210,11 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
|
|||
err = error.StructuralError("zero length signature subpacket")
|
||||
return
|
||||
}
|
||||
packetType := subpacket[0] & 0x7f
|
||||
isCritial := subpacket[0]&0x80 == 0x80
|
||||
packetType = signatureSubpacketType(subpacket[0] & 0x7f)
|
||||
isCritical = subpacket[0]&0x80 == 0x80
|
||||
subpacket = subpacket[1:]
|
||||
switch signatureSubpacketType(packetType) {
|
||||
sig.rawSubpackets = append(sig.rawSubpackets, outputSubpacket{isHashed, packetType, isCritical, subpacket})
|
||||
switch packetType {
|
||||
case creationTimeSubpacket:
|
||||
if !isHashed {
|
||||
err = error.StructuralError("signature creation time in non-hashed area")
|
||||
|
@ -309,7 +313,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
|
|||
}
|
||||
|
||||
default:
|
||||
if isCritial {
|
||||
if isCritical {
|
||||
err = error.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType)))
|
||||
return
|
||||
}
|
||||
|
@ -381,7 +385,6 @@ func serializeSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
|
|||
|
||||
// buildHashSuffix constructs the HashSuffix member of sig in preparation for signing.
|
||||
func (sig *Signature) buildHashSuffix() (err os.Error) {
|
||||
sig.outSubpackets = sig.buildSubpackets()
|
||||
hashedSubpacketsLen := subpacketsLength(sig.outSubpackets, true)
|
||||
|
||||
var ok bool
|
||||
|
@ -393,7 +396,7 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
|
|||
sig.HashSuffix[3], ok = s2k.HashToHashId(sig.Hash)
|
||||
if !ok {
|
||||
sig.HashSuffix = nil
|
||||
return error.InvalidArgumentError("hash cannot be repesented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
|
||||
return error.InvalidArgumentError("hash cannot be represented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
|
||||
}
|
||||
sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
|
||||
sig.HashSuffix[5] = byte(hashedSubpacketsLen)
|
||||
|
@ -420,45 +423,72 @@ func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err os.Error)
|
|||
return
|
||||
}
|
||||
|
||||
// SignRSA signs a message with an RSA private key. The hash, h, must contain
|
||||
// Sign signs a message with a private key. The hash, h, must contain
|
||||
// the hash of the message to be signed and will be mutated by this function.
|
||||
// On success, the signature is stored in sig. Call Serialize to write it out.
|
||||
func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) {
|
||||
func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err os.Error) {
|
||||
sig.outSubpackets = sig.buildSubpackets()
|
||||
digest, err := sig.signPrepareHash(h)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sig.RSASignature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
|
||||
|
||||
switch priv.PubKeyAlgo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
|
||||
sig.RSASignature.bytes, err = rsa.SignPKCS1v15(rand.Reader, priv.PrivateKey.(*rsa.PrivateKey), sig.Hash, digest)
|
||||
sig.RSASignature.bitLength = uint16(8 * len(sig.RSASignature.bytes))
|
||||
case PubKeyAlgoDSA:
|
||||
r, s, err := dsa.Sign(rand.Reader, priv.PrivateKey.(*dsa.PrivateKey), digest)
|
||||
if err == nil {
|
||||
sig.DSASigR.bytes = r.Bytes()
|
||||
sig.DSASigR.bitLength = uint16(8 * len(sig.DSASigR.bytes))
|
||||
sig.DSASigS.bytes = s.Bytes()
|
||||
sig.DSASigS.bitLength = uint16(8 * len(sig.DSASigS.bytes))
|
||||
}
|
||||
default:
|
||||
err = error.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// SignDSA signs a message with a DSA private key. The hash, h, must contain
|
||||
// the hash of the message to be signed and will be mutated by this function.
|
||||
// On success, the signature is stored in sig. Call Serialize to write it out.
|
||||
func (sig *Signature) SignDSA(h hash.Hash, priv *dsa.PrivateKey) (err os.Error) {
|
||||
digest, err := sig.signPrepareHash(h)
|
||||
// SignUserId computes a signature from priv, asserting that pub is a valid
|
||||
// key for the identity id. On success, the signature is stored in sig. Call
|
||||
// Serialize to write it out.
|
||||
func (sig *Signature) SignUserId(id string, pub *PublicKey, priv *PrivateKey) os.Error {
|
||||
h, err := userIdSignatureHash(id, pub, sig)
|
||||
if err != nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
sig.DSASigR, sig.DSASigS, err = dsa.Sign(rand.Reader, priv, digest)
|
||||
return
|
||||
return sig.Sign(h, priv)
|
||||
}
|
||||
|
||||
// SignKey computes a signature from priv, asserting that pub is a subkey. On
|
||||
// success, the signature is stored in sig. Call Serialize to write it out.
|
||||
func (sig *Signature) SignKey(pub *PublicKey, priv *PrivateKey) os.Error {
|
||||
h, err := keySignatureHash(&priv.PublicKey, pub, sig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sig.Sign(h, priv)
|
||||
}
|
||||
|
||||
// Serialize marshals sig to w. SignRSA or SignDSA must have been called first.
|
||||
func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
|
||||
if sig.RSASignature == nil && sig.DSASigR == nil {
|
||||
if len(sig.outSubpackets) == 0 {
|
||||
sig.outSubpackets = sig.rawSubpackets
|
||||
}
|
||||
if sig.RSASignature.bytes == nil && sig.DSASigR.bytes == nil {
|
||||
return error.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
|
||||
}
|
||||
|
||||
sigLength := 0
|
||||
switch sig.PubKeyAlgo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
|
||||
sigLength = len(sig.RSASignature)
|
||||
sigLength = 2 + len(sig.RSASignature.bytes)
|
||||
case PubKeyAlgoDSA:
|
||||
sigLength = 2 /* MPI length */
|
||||
sigLength += (sig.DSASigR.BitLen() + 7) / 8
|
||||
sigLength += 2 /* MPI length */
|
||||
sigLength += (sig.DSASigS.BitLen() + 7) / 8
|
||||
sigLength = 2 + len(sig.DSASigR.bytes)
|
||||
sigLength += 2 + len(sig.DSASigS.bytes)
|
||||
default:
|
||||
panic("impossible")
|
||||
}
|
||||
|
@ -466,7 +496,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
|
|||
unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false)
|
||||
length := len(sig.HashSuffix) - 6 /* trailer not included */ +
|
||||
2 /* length of unhashed subpackets */ + unhashedSubpacketsLen +
|
||||
2 /* hash tag */ + 2 /* length of signature MPI */ + sigLength
|
||||
2 /* hash tag */ + sigLength
|
||||
err = serializeHeader(w, packetTypeSignature, length)
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -493,12 +523,9 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
|
|||
|
||||
switch sig.PubKeyAlgo {
|
||||
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
|
||||
err = writeMPI(w, 8*uint16(len(sig.RSASignature)), sig.RSASignature)
|
||||
err = writeMPIs(w, sig.RSASignature)
|
||||
case PubKeyAlgoDSA:
|
||||
err = writeBig(w, sig.DSASigR)
|
||||
if err == nil {
|
||||
err = writeBig(w, sig.DSASigS)
|
||||
}
|
||||
err = writeMPIs(w, sig.DSASigR, sig.DSASigS)
|
||||
default:
|
||||
panic("impossible")
|
||||
}
|
||||
|
@ -509,6 +536,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
|
|||
type outputSubpacket struct {
|
||||
hashed bool // true if this subpacket is in the hashed area.
|
||||
subpacketType signatureSubpacketType
|
||||
isCritical bool
|
||||
contents []byte
|
||||
}
|
||||
|
||||
|
@ -518,12 +546,12 @@ func (sig *Signature) buildSubpackets() (subpackets []outputSubpacket) {
|
|||
creationTime[1] = byte(sig.CreationTime >> 16)
|
||||
creationTime[2] = byte(sig.CreationTime >> 8)
|
||||
creationTime[3] = byte(sig.CreationTime)
|
||||
subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, creationTime})
|
||||
subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, false, creationTime})
|
||||
|
||||
if sig.IssuerKeyId != nil {
|
||||
keyId := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(keyId, *sig.IssuerKeyId)
|
||||
subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, keyId})
|
||||
subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, false, keyId})
|
||||
}
|
||||
|
||||
return
|
||||
|
|
|
@ -12,9 +12,7 @@ import (
|
|||
)
|
||||
|
||||
func TestSignatureRead(t *testing.T) {
|
||||
signatureData, _ := hex.DecodeString(signatureDataHex)
|
||||
buf := bytes.NewBuffer(signatureData)
|
||||
packet, err := Read(buf)
|
||||
packet, err := Read(readerFromHex(signatureDataHex))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -25,4 +23,20 @@ func TestSignatureRead(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
const signatureDataHex = "89011c04000102000605024cb45112000a0910ab105c91af38fb158f8d07ff5596ea368c5efe015bed6e78348c0f033c931d5f2ce5db54ce7f2a7e4b4ad64db758d65a7a71773edeab7ba2a9e0908e6a94a1175edd86c1d843279f045b021a6971a72702fcbd650efc393c5474d5b59a15f96d2eaad4c4c426797e0dcca2803ef41c6ff234d403eec38f31d610c344c06f2401c262f0993b2e66cad8a81ebc4322c723e0d4ba09fe917e8777658307ad8329adacba821420741009dfe87f007759f0982275d028a392c6ed983a0d846f890b36148c7358bdb8a516007fac760261ecd06076813831a36d0459075d1befa245ae7f7fb103d92ca759e9498fe60ef8078a39a3beda510deea251ea9f0a7f0df6ef42060f20780360686f3e400e"
|
||||
func TestSignatureReserialize(t *testing.T) {
|
||||
packet, _ := Read(readerFromHex(signatureDataHex))
|
||||
sig := packet.(*Signature)
|
||||
out := new(bytes.Buffer)
|
||||
err := sig.Serialize(out)
|
||||
if err != nil {
|
||||
t.Errorf("error reserializing: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
expected, _ := hex.DecodeString(signatureDataHex)
|
||||
if !bytes.Equal(expected, out.Bytes()) {
|
||||
t.Errorf("output doesn't match input (got vs expected):\n%s\n%s", hex.Dump(out.Bytes()), hex.Dump(expected))
|
||||
}
|
||||
}
|
||||
|
||||
const signatureDataHex = "c2c05c04000102000605024cb45112000a0910ab105c91af38fb158f8d07ff5596ea368c5efe015bed6e78348c0f033c931d5f2ce5db54ce7f2a7e4b4ad64db758d65a7a71773edeab7ba2a9e0908e6a94a1175edd86c1d843279f045b021a6971a72702fcbd650efc393c5474d5b59a15f96d2eaad4c4c426797e0dcca2803ef41c6ff234d403eec38f31d610c344c06f2401c262f0993b2e66cad8a81ebc4322c723e0d4ba09fe917e8777658307ad8329adacba821420741009dfe87f007759f0982275d028a392c6ed983a0d846f890b36148c7358bdb8a516007fac760261ecd06076813831a36d0459075d1befa245ae7f7fb103d92ca759e9498fe60ef8078a39a3beda510deea251ea9f0a7f0df6ef42060f20780360686f3e400e"
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package packet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"crypto/openpgp/error"
|
||||
"crypto/openpgp/s2k"
|
||||
|
@ -27,6 +28,8 @@ type SymmetricKeyEncrypted struct {
|
|||
encryptedKey []byte
|
||||
}
|
||||
|
||||
const symmetricKeyEncryptedVersion = 4
|
||||
|
||||
func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err os.Error) {
|
||||
// RFC 4880, section 5.3.
|
||||
var buf [2]byte
|
||||
|
@ -34,12 +37,12 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err os.Error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
if buf[0] != 4 {
|
||||
if buf[0] != symmetricKeyEncryptedVersion {
|
||||
return error.UnsupportedError("SymmetricKeyEncrypted version")
|
||||
}
|
||||
ske.CipherFunc = CipherFunction(buf[1])
|
||||
|
||||
if ske.CipherFunc.keySize() == 0 {
|
||||
if ske.CipherFunc.KeySize() == 0 {
|
||||
return error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
|
||||
}
|
||||
|
||||
|
@ -75,7 +78,7 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) os.Error {
|
|||
return nil
|
||||
}
|
||||
|
||||
key := make([]byte, ske.CipherFunc.keySize())
|
||||
key := make([]byte, ske.CipherFunc.KeySize())
|
||||
ske.s2k(key, passphrase)
|
||||
|
||||
if len(ske.encryptedKey) == 0 {
|
||||
|
@ -100,3 +103,60 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) os.Error {
|
|||
ske.Encrypted = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// SerializeSymmetricKeyEncrypted serializes a symmetric key packet to w. The
|
||||
// packet contains a random session key, encrypted by a key derived from the
|
||||
// given passphrase. The session key is returned and must be passed to
|
||||
// SerializeSymmetricallyEncrypted.
|
||||
func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err os.Error) {
|
||||
keySize := cipherFunc.KeySize()
|
||||
if keySize == 0 {
|
||||
return nil, error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
|
||||
}
|
||||
|
||||
s2kBuf := new(bytes.Buffer)
|
||||
keyEncryptingKey := make([]byte, keySize)
|
||||
// s2k.Serialize salts and stretches the passphrase, and writes the
|
||||
// resulting key to keyEncryptingKey and the s2k descriptor to s2kBuf.
|
||||
err = s2k.Serialize(s2kBuf, keyEncryptingKey, rand, passphrase)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s2kBytes := s2kBuf.Bytes()
|
||||
|
||||
packetLength := 2 /* header */ + len(s2kBytes) + 1 /* cipher type */ + keySize
|
||||
err = serializeHeader(w, packetTypeSymmetricKeyEncrypted, packetLength)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var buf [2]byte
|
||||
buf[0] = symmetricKeyEncryptedVersion
|
||||
buf[1] = byte(cipherFunc)
|
||||
_, err = w.Write(buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = w.Write(s2kBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sessionKey := make([]byte, keySize)
|
||||
_, err = io.ReadFull(rand, sessionKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
iv := make([]byte, cipherFunc.blockSize())
|
||||
c := cipher.NewCFBEncrypter(cipherFunc.new(keyEncryptingKey), iv)
|
||||
encryptedCipherAndKey := make([]byte, keySize+1)
|
||||
c.XORKeyStream(encryptedCipherAndKey, buf[1:])
|
||||
c.XORKeyStream(encryptedCipherAndKey[1:], sessionKey)
|
||||
_, err = w.Write(encryptedCipherAndKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
key = sessionKey
|
||||
return
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package packet
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -60,3 +61,41 @@ func TestSymmetricKeyEncrypted(t *testing.T) {
|
|||
|
||||
const symmetricallyEncryptedHex = "8c0d04030302371a0b38d884f02060c91cf97c9973b8e58e028e9501708ccfe618fb92afef7fa2d80ddadd93cf"
|
||||
const symmetricallyEncryptedContentsHex = "cb1062004d14c4df636f6e74656e74732e0a"
|
||||
|
||||
func TestSerializeSymmetricKeyEncrypted(t *testing.T) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
passphrase := []byte("testing")
|
||||
cipherFunc := CipherAES128
|
||||
|
||||
key, err := SerializeSymmetricKeyEncrypted(buf, rand.Reader, passphrase, cipherFunc)
|
||||
if err != nil {
|
||||
t.Errorf("failed to serialize: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
p, err := Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to reparse: %s", err)
|
||||
return
|
||||
}
|
||||
ske, ok := p.(*SymmetricKeyEncrypted)
|
||||
if !ok {
|
||||
t.Errorf("parsed a different packet type: %#v", p)
|
||||
return
|
||||
}
|
||||
|
||||
if !ske.Encrypted {
|
||||
t.Errorf("SKE not encrypted but should be")
|
||||
}
|
||||
if ske.CipherFunc != cipherFunc {
|
||||
t.Errorf("SKE cipher function is %d (expected %d)", ske.CipherFunc, cipherFunc)
|
||||
}
|
||||
err = ske.Decrypt(passphrase)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decrypt reparsed SKE: %s", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(key, ske.Key) {
|
||||
t.Errorf("keys don't match after Decrpyt: %x (original) vs %x (parsed)", key, ske.Key)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ package packet
|
|||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/openpgp/error"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/subtle"
|
||||
"hash"
|
||||
|
@ -24,6 +25,8 @@ type SymmetricallyEncrypted struct {
|
|||
prefix []byte
|
||||
}
|
||||
|
||||
const symmetricallyEncryptedVersion = 1
|
||||
|
||||
func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
|
||||
if se.MDC {
|
||||
// See RFC 4880, section 5.13.
|
||||
|
@ -32,7 +35,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if buf[0] != 1 {
|
||||
if buf[0] != symmetricallyEncryptedVersion {
|
||||
return error.UnsupportedError("unknown SymmetricallyEncrypted version")
|
||||
}
|
||||
}
|
||||
|
@ -44,7 +47,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
|
|||
// packet can be read. An incorrect key can, with high probability, be detected
|
||||
// immediately and this will result in a KeyIncorrect error being returned.
|
||||
func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, os.Error) {
|
||||
keySize := c.keySize()
|
||||
keySize := c.KeySize()
|
||||
if keySize == 0 {
|
||||
return nil, error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
|
||||
}
|
||||
|
@ -174,6 +177,9 @@ func (ser *seMDCReader) Read(buf []byte) (n int, err os.Error) {
|
|||
return
|
||||
}
|
||||
|
||||
// This is a new-format packet tag byte for a type 19 (MDC) packet.
|
||||
const mdcPacketTagByte = byte(0x80) | 0x40 | 19
|
||||
|
||||
func (ser *seMDCReader) Close() os.Error {
|
||||
if ser.error {
|
||||
return error.SignatureError("error during reading")
|
||||
|
@ -191,16 +197,95 @@ func (ser *seMDCReader) Close() os.Error {
|
|||
}
|
||||
}
|
||||
|
||||
// This is a new-format packet tag byte for a type 19 (MDC) packet.
|
||||
const mdcPacketTagByte = byte(0x80) | 0x40 | 19
|
||||
if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
|
||||
return error.SignatureError("MDC packet not found")
|
||||
}
|
||||
ser.h.Write(ser.trailer[:2])
|
||||
|
||||
final := ser.h.Sum()
|
||||
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) == 1 {
|
||||
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
|
||||
return error.SignatureError("hash mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// An seMDCWriter writes through to an io.WriteCloser while maintains a running
|
||||
// hash of the data written. On close, it emits an MDC packet containing the
|
||||
// running hash.
|
||||
type seMDCWriter struct {
|
||||
w io.WriteCloser
|
||||
h hash.Hash
|
||||
}
|
||||
|
||||
func (w *seMDCWriter) Write(buf []byte) (n int, err os.Error) {
|
||||
w.h.Write(buf)
|
||||
return w.w.Write(buf)
|
||||
}
|
||||
|
||||
func (w *seMDCWriter) Close() (err os.Error) {
|
||||
var buf [mdcTrailerSize]byte
|
||||
|
||||
buf[0] = mdcPacketTagByte
|
||||
buf[1] = sha1.Size
|
||||
w.h.Write(buf[:2])
|
||||
digest := w.h.Sum()
|
||||
copy(buf[2:], digest)
|
||||
|
||||
_, err = w.w.Write(buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return w.w.Close()
|
||||
}
|
||||
|
||||
// noOpCloser is like an ioutil.NopCloser, but for an io.Writer.
|
||||
type noOpCloser struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (c noOpCloser) Write(data []byte) (n int, err os.Error) {
|
||||
return c.w.Write(data)
|
||||
}
|
||||
|
||||
func (c noOpCloser) Close() os.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
|
||||
// to w and returns a WriteCloser to which the to-be-encrypted packets can be
|
||||
// written.
|
||||
func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte) (contents io.WriteCloser, err os.Error) {
|
||||
if c.KeySize() != len(key) {
|
||||
return nil, error.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
|
||||
}
|
||||
writeCloser := noOpCloser{w}
|
||||
ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = ciphertext.Write([]byte{symmetricallyEncryptedVersion})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
block := c.new(key)
|
||||
blockSize := block.BlockSize()
|
||||
iv := make([]byte, blockSize)
|
||||
_, err = rand.Reader.Read(iv)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s, prefix := cipher.NewOCFBEncrypter(block, iv, cipher.OCFBNoResync)
|
||||
_, err = ciphertext.Write(prefix)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
plaintext := cipher.StreamWriter{S: s, W: ciphertext}
|
||||
|
||||
h := sha1.New()
|
||||
h.Write(iv)
|
||||
h.Write(iv[blockSize-2:])
|
||||
contents = &seMDCWriter{w: plaintext, h: h}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"crypto/openpgp/error"
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
@ -76,3 +77,48 @@ func testMDCReader(t *testing.T) {
|
|||
}
|
||||
|
||||
const mdcPlaintextHex = "a302789c3b2d93c4e0eb9aba22283539b3203335af44a134afb800c849cb4c4de10200aff40b45d31432c80cb384299a0655966d6939dfdeed1dddf980"
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
c := CipherAES128
|
||||
key := make([]byte, c.KeySize())
|
||||
|
||||
w, err := SerializeSymmetricallyEncrypted(buf, c, key)
|
||||
if err != nil {
|
||||
t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
contents := []byte("hello world\n")
|
||||
|
||||
w.Write(contents)
|
||||
w.Close()
|
||||
|
||||
p, err := Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("error from Read: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
se, ok := p.(*SymmetricallyEncrypted)
|
||||
if !ok {
|
||||
t.Errorf("didn't read a *SymmetricallyEncrypted")
|
||||
return
|
||||
}
|
||||
|
||||
r, err := se.Decrypt(c, key)
|
||||
if err != nil {
|
||||
t.Errorf("error from Decrypt: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
contentsCopy := bytes.NewBuffer(nil)
|
||||
_, err = io.Copy(contentsCopy, r)
|
||||
if err != nil {
|
||||
t.Errorf("error from io.Copy: %s", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(contentsCopy.Bytes(), contents) {
|
||||
t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
|
||||
}
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue