Source file src/go/types/hilbert_test.go

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package types_test
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/importer"
    13  	"go/parser"
    14  	"go/token"
    15  	"os"
    16  	"testing"
    17  
    18  	. "go/types"
    19  )
    20  
    21  var (
    22  	H   = flag.Int("H", 5, "Hilbert matrix size")
    23  	out = flag.String("out", "", "write generated program to out")
    24  )
    25  
    26  func TestHilbert(t *testing.T) {
    27  	// generate source
    28  	src := program(*H, *out)
    29  	if *out != "" {
    30  		os.WriteFile(*out, src, 0666)
    31  		return
    32  	}
    33  
    34  	// parse source
    35  	fset := token.NewFileSet()
    36  	f, err := parser.ParseFile(fset, "hilbert.go", src, 0)
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  
    41  	// type-check file
    42  	DefPredeclaredTestFuncs() // define assert built-in
    43  	conf := Config{Importer: importer.Default()}
    44  	_, err = conf.Check(f.Name.Name, fset, []*ast.File{f}, nil)
    45  	if err != nil {
    46  		t.Fatal(err)
    47  	}
    48  }
    49  
    50  func program(n int, out string) []byte {
    51  	var g gen
    52  
    53  	g.p(`// Code generated by: go test -run=Hilbert -H=%d -out=%q. DO NOT EDIT.
    54  
    55  // +`+`build ignore
    56  
    57  // This program tests arbitrary precision constant arithmetic
    58  // by generating the constant elements of a Hilbert matrix H,
    59  // its inverse I, and the product P = H*I. The product should
    60  // be the identity matrix.
    61  package main
    62  
    63  func main() {
    64  	if !ok {
    65  		printProduct()
    66  		return
    67  	}
    68  	println("PASS")
    69  }
    70  
    71  `, n, out)
    72  	g.hilbert(n)
    73  	g.inverse(n)
    74  	g.product(n)
    75  	g.verify(n)
    76  	g.printProduct(n)
    77  	g.binomials(2*n - 1)
    78  	g.factorials(2*n - 1)
    79  
    80  	return g.Bytes()
    81  }
    82  
    83  type gen struct {
    84  	bytes.Buffer
    85  }
    86  
    87  func (g *gen) p(format string, args ...any) {
    88  	fmt.Fprintf(&g.Buffer, format, args...)
    89  }
    90  
    91  func (g *gen) hilbert(n int) {
    92  	g.p(`// Hilbert matrix, n = %d
    93  const (
    94  `, n)
    95  	for i := 0; i < n; i++ {
    96  		g.p("\t")
    97  		for j := 0; j < n; j++ {
    98  			if j > 0 {
    99  				g.p(", ")
   100  			}
   101  			g.p("h%d_%d", i, j)
   102  		}
   103  		if i == 0 {
   104  			g.p(" = ")
   105  			for j := 0; j < n; j++ {
   106  				if j > 0 {
   107  					g.p(", ")
   108  				}
   109  				g.p("1.0/(iota + %d)", j+1)
   110  			}
   111  		}
   112  		g.p("\n")
   113  	}
   114  	g.p(")\n\n")
   115  }
   116  
   117  func (g *gen) inverse(n int) {
   118  	g.p(`// Inverse Hilbert matrix
   119  const (
   120  `)
   121  	for i := 0; i < n; i++ {
   122  		for j := 0; j < n; j++ {
   123  			s := "+"
   124  			if (i+j)&1 != 0 {
   125  				s = "-"
   126  			}
   127  			g.p("\ti%d_%d = %s%d * b%d_%d * b%d_%d * b%d_%d * b%d_%d\n",
   128  				i, j, s, i+j+1, n+i, n-j-1, n+j, n-i-1, i+j, i, i+j, i)
   129  		}
   130  		g.p("\n")
   131  	}
   132  	g.p(")\n\n")
   133  }
   134  
   135  func (g *gen) product(n int) {
   136  	g.p(`// Product matrix
   137  const (
   138  `)
   139  	for i := 0; i < n; i++ {
   140  		for j := 0; j < n; j++ {
   141  			g.p("\tp%d_%d = ", i, j)
   142  			for k := 0; k < n; k++ {
   143  				if k > 0 {
   144  					g.p(" + ")
   145  				}
   146  				g.p("h%d_%d*i%d_%d", i, k, k, j)
   147  			}
   148  			g.p("\n")
   149  		}
   150  		g.p("\n")
   151  	}
   152  	g.p(")\n\n")
   153  }
   154  
   155  func (g *gen) verify(n int) {
   156  	g.p(`// Verify that product is the identity matrix
   157  const ok =
   158  `)
   159  	for i := 0; i < n; i++ {
   160  		for j := 0; j < n; j++ {
   161  			if j == 0 {
   162  				g.p("\t")
   163  			} else {
   164  				g.p(" && ")
   165  			}
   166  			v := 0
   167  			if i == j {
   168  				v = 1
   169  			}
   170  			g.p("p%d_%d == %d", i, j, v)
   171  		}
   172  		g.p(" &&\n")
   173  	}
   174  	g.p("\ttrue\n\n")
   175  
   176  	// verify ok at type-check time
   177  	if *out == "" {
   178  		g.p("const _ = assert(ok)\n\n")
   179  	}
   180  }
   181  
   182  func (g *gen) printProduct(n int) {
   183  	g.p("func printProduct() {\n")
   184  	for i := 0; i < n; i++ {
   185  		g.p("\tprintln(")
   186  		for j := 0; j < n; j++ {
   187  			if j > 0 {
   188  				g.p(", ")
   189  			}
   190  			g.p("p%d_%d", i, j)
   191  		}
   192  		g.p(")\n")
   193  	}
   194  	g.p("}\n\n")
   195  }
   196  
   197  func (g *gen) binomials(n int) {
   198  	g.p(`// Binomials
   199  const (
   200  `)
   201  	for j := 0; j <= n; j++ {
   202  		if j > 0 {
   203  			g.p("\n")
   204  		}
   205  		for k := 0; k <= j; k++ {
   206  			g.p("\tb%d_%d = f%d / (f%d*f%d)\n", j, k, j, k, j-k)
   207  		}
   208  	}
   209  	g.p(")\n\n")
   210  }
   211  
   212  func (g *gen) factorials(n int) {
   213  	g.p(`// Factorials
   214  const (
   215  	f0 = 1
   216  	f1 = 1
   217  `)
   218  	for i := 2; i <= n; i++ {
   219  		g.p("\tf%d = f%d * %d\n", i, i-1, i)
   220  	}
   221  	g.p(")\n\n")
   222  }
   223  

View as plain text