1
2
3
4
5 package ssa
6
7 import (
8 "fmt"
9 "os"
10 )
11
12
// debugPoset enables expensive consistency checks (CheckIntegrity after
// every public operation, CheckEmpty after a full Undo).
var debugPoset bool

// uintSize is the width in bits of a uint: 32 or 64 depending on the platform.
const uintSize = 32 << (^uint(0) >> 63)

// bitset is a fixed-size set of small non-negative integers, packed into
// machine words.
type bitset []uint

// newBitset returns a zeroed bitset able to hold the indices [0, n).
func newBitset(n int) bitset {
	words := (n + uintSize - 1) / uintSize
	return make(bitset, words)
}

// Reset clears every bit in the set, keeping its size.
func (bs bitset) Reset() {
	for w := range bs {
		bs[w] = 0
	}
}

// Set adds idx to the set.
func (bs bitset) Set(idx uint32) {
	word, bit := idx/uintSize, idx%uintSize
	bs[word] |= 1 << bit
}

// Clear removes idx from the set.
func (bs bitset) Clear(idx uint32) {
	word, bit := idx/uintSize, idx%uintSize
	bs[word] &^= 1 << bit
}

// Test reports whether idx is in the set.
func (bs bitset) Test(idx uint32) bool {
	word, bit := idx/uintSize, idx%uintSize
	mask := uint(1) << bit
	return bs[word]&mask != 0
}
41
// undoType identifies the kind of change recorded by a posetUndo pass.
type undoType uint8

const (
	undoInvalid     undoType = iota // zero value; never pushed, Undo panics on it
	undoCheckpoint                  // marker pushed by Checkpoint; Undo stops here
	undoSetChl                      // restore the previous left child edge of a node
	undoSetChr                      // restore the previous right child edge of a node
	undoNonEqual                    // forget a recorded non-equality
	undoNewNode                     // remove the most recently allocated node
	undoNewConstant                 // forget (or restore) a constant -> node mapping
	undoAliasNode                   // forget (or restore) a Value -> node mapping
	undoNewRoot                     // remove a DAG root that was added
	undoChangeRoot                  // put back the root that was replaced
	undoMergeRoot                   // split a merged root back into its two DAGs
)
57
58
59
60
61
// posetUndo is one entry (a "pass") of the poset undo stack (poset.undo).
// Undo pops passes until it finds an undoCheckpoint. The meaning of the
// fields depends on typ:
//   - undoSetChl/undoSetChr: idx is the node, edge is its previous child edge.
//   - undoNonEqual: ID holds the larger node index and idx the smaller one.
//   - undoNewNode: idx is the new node, ID the Value ID bound to it (0 if none).
//   - undoNewConstant: idx is the constant's node, ID holds the previous node
//     index for that constant (0 if the constant is new).
//   - undoAliasNode: ID is the Value ID, idx its previous node index (0 if none).
//   - undoNewRoot/undoChangeRoot/undoMergeRoot: idx is the root node; for
//     undoChangeRoot, edge targets the root that was replaced.
type posetUndo struct {
	typ  undoType
	idx  uint32
	ID   ID
	edge posetEdge
}
68
// posetFlagUnsigned, when set in poset.flags (see SetUnsigned), makes the
// poset interpret integer constants as unsigned values when ordering them.
const posetFlagUnsigned = 1 << 0
73
74
75
// posetEdge is a compact edge between two poset nodes: the target node index
// is stored in the upper 31 bits and bit 0 is set for a strict edge. The zero
// posetEdge means "no edge".
type posetEdge uint32

// newedge builds an edge pointing at node t, strict if requested.
func newedge(t uint32, strict bool) posetEdge {
	e := posetEdge(t << 1)
	if strict {
		e |= 1
	}
	return e
}

// Target returns the index of the node the edge points to.
func (e posetEdge) Target() uint32 { return uint32(e >> 1) }

// Strict reports whether the edge records a strict (<) relation.
func (e posetEdge) Strict() bool { return e&1 == 1 }

// String formats the edge as its target index, suffixed with "*" when strict.
func (e posetEdge) String() string {
	target := fmt.Sprint(e.Target())
	if !e.Strict() {
		return target
	}
	return target + "*"
}

// posetNode is a node of a poset DAG; l and r are its (up to two) children.
type posetNode struct {
	l posetEdge
	r posetEdge
}
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149 type poset struct {
150 lastidx uint32
151 flags uint8
152 values map[ID]uint32
153 constants map[int64]uint32
154 nodes []posetNode
155 roots []uint32
156 noneq map[uint32]bitset
157 undo []posetUndo
158 }
159
160 func newPoset() *poset {
161 return &poset{
162 values: make(map[ID]uint32),
163 constants: make(map[int64]uint32, 8),
164 nodes: make([]posetNode, 1, 16),
165 roots: make([]uint32, 0, 4),
166 noneq: make(map[uint32]bitset),
167 undo: make([]posetUndo, 0, 4),
168 }
169 }
170
171 func (po *poset) SetUnsigned(uns bool) {
172 if uns {
173 po.flags |= posetFlagUnsigned
174 } else {
175 po.flags &^= posetFlagUnsigned
176 }
177 }
178
179
180 func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
181 func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
182 func (po *poset) chl(i uint32) uint32 { return po.nodes[i].l.Target() }
183 func (po *poset) chr(i uint32) uint32 { return po.nodes[i].r.Target() }
184 func (po *poset) children(i uint32) (posetEdge, posetEdge) {
185 return po.nodes[i].l, po.nodes[i].r
186 }
187
188
189
190 func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
191 po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
192 }
193
194
195 func (po *poset) upushnew(id ID, idx uint32) {
196 po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
197 }
198
199
200 func (po *poset) upushneq(idx1 uint32, idx2 uint32) {
201 po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: ID(idx1), idx: idx2})
202 }
203
204
205 func (po *poset) upushalias(id ID, i2 uint32) {
206 po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
207 }
208
209
210 func (po *poset) upushconst(idx uint32, old uint32) {
211 po.undo = append(po.undo, posetUndo{typ: undoNewConstant, idx: idx, ID: ID(old)})
212 }
213
214
215 func (po *poset) addchild(i1, i2 uint32, strict bool) {
216 i1l, i1r := po.children(i1)
217 e2 := newedge(i2, strict)
218
219 if i1l == 0 {
220 po.setchl(i1, e2)
221 po.upush(undoSetChl, i1, 0)
222 } else if i1r == 0 {
223 po.setchr(i1, e2)
224 po.upush(undoSetChr, i1, 0)
225 } else {
226
227
228
229
230
231
232
233
234
235
236
237
238 extra := po.newnode(nil)
239 if (i1^i2)&1 != 0 {
240 po.setchl(extra, i1r)
241 po.setchr(extra, e2)
242 po.setchr(i1, newedge(extra, false))
243 po.upush(undoSetChr, i1, i1r)
244 } else {
245 po.setchl(extra, i1l)
246 po.setchr(extra, e2)
247 po.setchl(i1, newedge(extra, false))
248 po.upush(undoSetChl, i1, i1l)
249 }
250 }
251 }
252
253
254
255 func (po *poset) newnode(n *Value) uint32 {
256 i := po.lastidx + 1
257 po.lastidx++
258 po.nodes = append(po.nodes, posetNode{})
259 if n != nil {
260 if po.values[n.ID] != 0 {
261 panic("newnode for Value already inserted")
262 }
263 po.values[n.ID] = i
264 po.upushnew(n.ID, i)
265 } else {
266 po.upushnew(0, i)
267 }
268 return i
269 }
270
271
272
273 func (po *poset) lookup(n *Value) (uint32, bool) {
274 i, f := po.values[n.ID]
275 if !f && n.isGenericIntConst() {
276 po.newconst(n)
277 i, f = po.values[n.ID]
278 }
279 return i, f
280 }
281
282
283
284
// newconst inserts the generic integer constant n, which must not already be
// mapped in po.values, and records its order relative to the constants
// already known. All constants live in the DAG rooted at po.roots[0], and each
// new constant is linked to its nearest lower and/or higher constant with
// strict edges, so the order between any two constants follows from
// reachability.
//
// If another Value with the same constant value was inserted before, n is
// simply aliased to that node.
func (po *poset) newconst(n *Value) {
	if !n.isGenericIntConst() {
		panic("newconst on non-constant")
	}

	// po.constants is keyed by int64. In unsigned mode the key is the
	// unsigned interpretation of the constant, converted back to int64.
	val := n.AuxInt
	if po.flags&posetFlagUnsigned != 0 {
		val = int64(n.AuxUnsigned())
	}
	if c, found := po.constants[val]; found {
		// Same constant value already has a node: alias n to it.
		po.values[n.ID] = c
		po.upushalias(n.ID, 0)
		return
	}

	// Create the node for the new constant.
	i := po.newnode(n)

	// First constant ever: make it a new root and move it to index 0 of
	// po.roots, since constants must live in the first DAG (CheckIntegrity
	// enforces this).
	if len(po.constants) == 0 {
		idx := len(po.roots)
		po.roots = append(po.roots, i)
		po.roots[0], po.roots[idx] = po.roots[idx], po.roots[0]
		po.upush(undoNewRoot, i, 0)
		po.constants[val] = i
		po.upushconst(i, 0)
		return
	}

	// Find the closest existing constants below (lowerptr) and above
	// (higherptr) the new value, comparing as unsigned or signed according
	// to posetFlagUnsigned. A zero pointer means "no such constant".
	var lowerptr, higherptr uint32

	if po.flags&posetFlagUnsigned != 0 {
		var lower, higher uint64
		val1 := n.AuxUnsigned()
		for val2, ptr := range po.constants {
			val2 := uint64(val2)
			if val1 == val2 {
				// val was not found in po.constants above.
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	} else {
		var lower, higher int64
		val1 := n.AuxInt
		for val2, ptr := range po.constants {
			if val1 == val2 {
				// val was not found in po.constants above.
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	}

	if lowerptr == 0 && higherptr == 0 {
		// po.constants is non-empty and holds only values different from
		// val, so at least one bound must have been found.
		panic("no constant found")
	}

	// Link the new constant into the chain of constants with strict edges.
	switch {
	case lowerptr != 0 && higherptr != 0:
		// Both bounds: record lower < n < higher.
		po.addchild(lowerptr, i, true)
		po.addchild(i, higherptr, true)

	case lowerptr != 0:
		// Lower bound only: record lower < n.
		po.addchild(lowerptr, i, true)

	case higherptr != 0:
		// Higher bound only: record n < higher. n has no parent, so an
		// extra node becomes the new root of the constants DAG, with the
		// old root and n as its non-strict children:
		//
		//        extra
		//        /   \
		//      root   \
		//       /      n
		//     ....    /
		//       \    /
		//       higher
		//
		i2 := higherptr
		r2 := po.findroot(i2)
		if r2 != po.roots[0] {
			panic("constant not in root #0")
		}
		extra := po.newnode(nil)
		po.changeroot(r2, extra)
		po.upush(undoChangeRoot, extra, newedge(r2, false))
		po.addchild(extra, r2, false)
		po.addchild(extra, i, false)
		po.addchild(i, i2, true)
	}

	po.constants[val] = i
	po.upushconst(i, 0)
}
409
410
411
412 func (po *poset) aliasnewnode(n1, n2 *Value) {
413 i1, i2 := po.values[n1.ID], po.values[n2.ID]
414 if i1 == 0 || i2 != 0 {
415 panic("aliasnewnode invalid arguments")
416 }
417
418 po.values[n2.ID] = i1
419 po.upushalias(n2.ID, 0)
420 }
421
422
423
424
425
426
// aliasnodes records that all the nodes in i2s are equal to the node of n1
// (i1). Every edge pointing into i2s is redirected to i1, the children of
// each node in i2s are attached to i1 (unless they are in i2s themselves) and
// then detached, and every Value and constant mapped to a node in i2s is
// remapped to i1. Every change is pushed on the undo stack.
//
// It panics if n1 is not in the poset or if i2s contains i1.
func (po *poset) aliasnodes(n1 *Value, i2s bitset) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}
	if i2s.Test(i1) {
		panic("aliasnode i2s contains n1 node")
	}

	// Go through all the nodes to adjust parents/children of nodes in i2s.
	// The range expression is evaluated once, so extra nodes appended by
	// addchild during the loop are not visited.
	for idx, n := range po.nodes {
		// Do not touch i1 itself, otherwise we could create self-loops.
		if uint32(idx) == i1 {
			continue
		}
		l, r := n.l, n.r

		// Rename all references to i2s into i1, keeping strictness.
		if i2s.Test(l.Target()) {
			po.setchl(uint32(idx), newedge(i1, l.Strict()))
			po.upush(undoSetChl, uint32(idx), l)
		}
		if i2s.Test(r.Target()) {
			po.setchr(uint32(idx), newedge(i1, r.Strict()))
			po.upush(undoSetChr, uint32(idx), r)
		}

		// Connect all children of i2s to i1 (unless those children are in
		// i2s as well, in which case it would be useless), then detach them.
		if i2s.Test(uint32(idx)) {
			if l != 0 && !i2s.Test(l.Target()) {
				po.addchild(i1, l.Target(), l.Strict())
			}
			if r != 0 && !i2s.Test(r.Target()) {
				po.addchild(i1, r.Target(), r.Strict())
			}
			po.setchl(uint32(idx), 0)
			po.setchr(uint32(idx), 0)
			po.upush(undoSetChl, uint32(idx), l)
			po.upush(undoSetChr, uint32(idx), r)
		}
	}

	// Reassign all Values that point to a node in i2s to i1, remembering
	// the old index so Undo can restore it.
	for k, v := range po.values {
		if i2s.Test(v) {
			po.values[k] = i1
			po.upushalias(k, v)
		}
	}

	// If one of the aliased nodes is a constant, make po.constants point
	// to i1 instead.
	for val, idx := range po.constants {
		if i2s.Test(idx) {
			po.constants[val] = i1
			po.upushconst(i1, idx)
		}
	}
}
488
489 func (po *poset) isroot(r uint32) bool {
490 for i := range po.roots {
491 if po.roots[i] == r {
492 return true
493 }
494 }
495 return false
496 }
497
498 func (po *poset) changeroot(oldr, newr uint32) {
499 for i := range po.roots {
500 if po.roots[i] == oldr {
501 po.roots[i] = newr
502 return
503 }
504 }
505 panic("changeroot on non-root")
506 }
507
508 func (po *poset) removeroot(r uint32) {
509 for i := range po.roots {
510 if po.roots[i] == r {
511 po.roots = append(po.roots[:i], po.roots[i+1:]...)
512 return
513 }
514 }
515 panic("removeroot on non-root")
516 }
517
518
519
520
521
522
523
524
525
// dfs performs a depth-first search within the DAG whose root is r.
// f is the visit function called for each visited node; if it returns true,
// the search is aborted and dfs returns true. Otherwise dfs returns false.
//
// With strict == false, f is called on every node reachable from r,
// including r itself. With strict == true, f is called only on nodes
// reachable from r through a path containing at least one strict edge.
func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
	closed := newBitset(int(po.lastidx + 1))
	open := make([]uint32, 1, 64)
	open[0] = r

	if strict {
		// First pass: walk only non-strict edges from r, collecting the
		// targets of the strict edges found along the way in next. Those
		// nodes seed the open list of the real DFS below.
		next := make([]uint32, 0, 64)

		for len(open) > 0 {
			i := open[len(open)-1]
			open = open[:len(open)-1]

			// Don't visit the same node twice. Every node reachable
			// through non-strict edges is still expanded once, so all
			// strict edges leaving that region are collected.
			if !closed.Test(i) {
				closed.Set(i)

				l, r := po.children(i)
				if l != 0 {
					if l.Strict() {
						next = append(next, l.Target())
					} else {
						open = append(open, l.Target())
					}
				}
				if r != 0 {
					if r.Strict() {
						next = append(next, r.Target())
					} else {
						open = append(open, r.Target())
					}
				}
			}
		}
		open = next
		closed.Reset()
	}

	// Main pass: visit every node reachable from the open list.
	for len(open) > 0 {
		i := open[len(open)-1]
		open = open[:len(open)-1]

		if !closed.Test(i) {
			if f(i) {
				return true
			}
			closed.Set(i)
			l, r := po.children(i)
			if l != 0 {
				open = append(open, l.Target())
			}
			if r != 0 {
				open = append(open, r.Target())
			}
		}
	}
	return false
}
589
590
591
592
593
594 func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
595 return po.dfs(i1, strict, func(n uint32) bool {
596 return n == i2
597 })
598 }
599
600
601
602
603 func (po *poset) findroot(i uint32) uint32 {
604
605
606
607 for _, r := range po.roots {
608 if po.reaches(r, i, false) {
609 return r
610 }
611 }
612 panic("findroot didn't find any root")
613 }
614
615
616 func (po *poset) mergeroot(r1, r2 uint32) uint32 {
617
618
619
620 if r2 == po.roots[0] {
621 r1, r2 = r2, r1
622 }
623 r := po.newnode(nil)
624 po.setchl(r, newedge(r1, false))
625 po.setchr(r, newedge(r2, false))
626 po.changeroot(r1, r)
627 po.removeroot(r2)
628 po.upush(undoMergeRoot, r, 0)
629 return r
630 }
631
632
633
634
635
636 func (po *poset) collapsepath(n1, n2 *Value) bool {
637 i1, i2 := po.values[n1.ID], po.values[n2.ID]
638 if po.reaches(i1, i2, true) {
639 return false
640 }
641
642
643 paths := po.findpaths(i1, i2)
644
645
646 paths.Clear(i1)
647 po.aliasnodes(n1, paths)
648 return true
649 }
650
651
652
653
654
655
656
657 func (po *poset) findpaths(cur, dst uint32) bitset {
658 seen := newBitset(int(po.lastidx + 1))
659 path := newBitset(int(po.lastidx + 1))
660 path.Set(dst)
661 po.findpaths1(cur, dst, seen, path)
662 return path
663 }
664
665 func (po *poset) findpaths1(cur, dst uint32, seen bitset, path bitset) {
666 if cur == dst {
667 return
668 }
669 seen.Set(cur)
670 l, r := po.chl(cur), po.chr(cur)
671 if !seen.Test(l) {
672 po.findpaths1(l, dst, seen, path)
673 }
674 if !seen.Test(r) {
675 po.findpaths1(r, dst, seen, path)
676 }
677 if path.Test(l) || path.Test(r) {
678 path.Set(cur)
679 }
680 }
681
682
683 func (po *poset) isnoneq(i1, i2 uint32) bool {
684 if i1 == i2 {
685 return false
686 }
687 if i1 < i2 {
688 i1, i2 = i2, i1
689 }
690
691
692 if bs, ok := po.noneq[i1]; ok && bs.Test(i2) {
693 return true
694 }
695 return false
696 }
697
698
699 func (po *poset) setnoneq(n1, n2 *Value) {
700 i1, f1 := po.lookup(n1)
701 i2, f2 := po.lookup(n2)
702
703
704
705
706 if !f1 {
707 i1 = po.newnode(n1)
708 po.roots = append(po.roots, i1)
709 po.upush(undoNewRoot, i1, 0)
710 }
711 if !f2 {
712 i2 = po.newnode(n2)
713 po.roots = append(po.roots, i2)
714 po.upush(undoNewRoot, i2, 0)
715 }
716
717 if i1 == i2 {
718 panic("setnoneq on same node")
719 }
720 if i1 < i2 {
721 i1, i2 = i2, i1
722 }
723 bs := po.noneq[i1]
724 if bs == nil {
725
726
727
728
729 bs = newBitset(int(i1))
730 po.noneq[i1] = bs
731 } else if bs.Test(i2) {
732
733 return
734 }
735 bs.Set(i2)
736 po.upushneq(i1, i2)
737 }
738
739
740
741 func (po *poset) CheckIntegrity() {
742
743 constants := newBitset(int(po.lastidx + 1))
744 for _, c := range po.constants {
745 constants.Set(c)
746 }
747
748
749
750 seen := newBitset(int(po.lastidx + 1))
751 for ridx, r := range po.roots {
752 if r == 0 {
753 panic("empty root")
754 }
755
756 po.dfs(r, false, func(i uint32) bool {
757 if seen.Test(i) {
758 panic("duplicate node")
759 }
760 seen.Set(i)
761 if constants.Test(i) {
762 if ridx != 0 {
763 panic("constants not in the first DAG")
764 }
765 }
766 return false
767 })
768 }
769
770
771 for id, idx := range po.values {
772 if !seen.Test(idx) {
773 panic(fmt.Errorf("spurious value [%d]=%d", id, idx))
774 }
775 }
776
777
778 for i, n := range po.nodes {
779 if n.l|n.r != 0 {
780 if !seen.Test(uint32(i)) {
781 panic(fmt.Errorf("children of unknown node %d->%v", i, n))
782 }
783 if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
784 panic(fmt.Errorf("self-loop on node %d", i))
785 }
786 }
787 }
788 }
789
790
791
792
793 func (po *poset) CheckEmpty() error {
794 if len(po.nodes) != 1 {
795 return fmt.Errorf("non-empty nodes list: %v", po.nodes)
796 }
797 if len(po.values) != 0 {
798 return fmt.Errorf("non-empty value map: %v", po.values)
799 }
800 if len(po.roots) != 0 {
801 return fmt.Errorf("non-empty root list: %v", po.roots)
802 }
803 if len(po.constants) != 0 {
804 return fmt.Errorf("non-empty constants: %v", po.constants)
805 }
806 if len(po.undo) != 0 {
807 return fmt.Errorf("non-empty undo list: %v", po.undo)
808 }
809 if po.lastidx != 0 {
810 return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
811 }
812 for _, bs := range po.noneq {
813 for _, x := range bs {
814 if x != 0 {
815 return fmt.Errorf("non-empty noneq map")
816 }
817 }
818 }
819 return nil
820 }
821
822
823 func (po *poset) DotDump(fn string, title string) error {
824 f, err := os.Create(fn)
825 if err != nil {
826 return err
827 }
828 defer f.Close()
829
830
831 names := make(map[uint32]string)
832 for id, i := range po.values {
833 s := names[i]
834 if s == "" {
835 s = fmt.Sprintf("v%d", id)
836 } else {
837 s += fmt.Sprintf(", v%d", id)
838 }
839 names[i] = s
840 }
841
842
843 consts := make(map[uint32]int64)
844 for val, idx := range po.constants {
845 consts[idx] = val
846 }
847
848 fmt.Fprintf(f, "digraph poset {\n")
849 fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
850 for ridx, r := range po.roots {
851 fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
852 po.dfs(r, false, func(i uint32) bool {
853 if val, ok := consts[i]; ok {
854
855 var vals string
856 if po.flags&posetFlagUnsigned != 0 {
857 vals = fmt.Sprint(uint64(val))
858 } else {
859 vals = fmt.Sprint(int64(val))
860 }
861 fmt.Fprintf(f, "\t\tnode%d [shape=box style=filled fillcolor=cadetblue1 label=<%s <font point-size=\"6\">%s [%d]</font>>]\n",
862 i, vals, names[i], i)
863 } else {
864
865 fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
866 }
867 chl, chr := po.children(i)
868 for _, ch := range []posetEdge{chl, chr} {
869 if ch != 0 {
870 if ch.Strict() {
871 fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
872 } else {
873 fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
874 }
875 }
876 }
877 return false
878 })
879 fmt.Fprintf(f, "\t}\n")
880 }
881 fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
882 fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
883 fmt.Fprintf(f, "\tlabel=%q\n", title)
884 fmt.Fprintf(f, "}\n")
885 return nil
886 }
887
888
889
890
891
892 func (po *poset) Ordered(n1, n2 *Value) bool {
893 if debugPoset {
894 defer po.CheckIntegrity()
895 }
896 if n1.ID == n2.ID {
897 panic("should not call Ordered with n1==n2")
898 }
899
900 i1, f1 := po.lookup(n1)
901 i2, f2 := po.lookup(n2)
902 if !f1 || !f2 {
903 return false
904 }
905
906 return i1 != i2 && po.reaches(i1, i2, true)
907 }
908
909
910
911
912
913 func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
914 if debugPoset {
915 defer po.CheckIntegrity()
916 }
917 if n1.ID == n2.ID {
918 panic("should not call Ordered with n1==n2")
919 }
920
921 i1, f1 := po.lookup(n1)
922 i2, f2 := po.lookup(n2)
923 if !f1 || !f2 {
924 return false
925 }
926
927 return i1 == i2 || po.reaches(i1, i2, false)
928 }
929
930
931
932
933
934 func (po *poset) Equal(n1, n2 *Value) bool {
935 if debugPoset {
936 defer po.CheckIntegrity()
937 }
938 if n1.ID == n2.ID {
939 panic("should not call Equal with n1==n2")
940 }
941
942 i1, f1 := po.lookup(n1)
943 i2, f2 := po.lookup(n2)
944 return f1 && f2 && i1 == i2
945 }
946
947
948
949
950
951
952 func (po *poset) NonEqual(n1, n2 *Value) bool {
953 if debugPoset {
954 defer po.CheckIntegrity()
955 }
956 if n1.ID == n2.ID {
957 panic("should not call NonEqual with n1==n2")
958 }
959
960
961
962 i1, f1 := po.lookup(n1)
963 i2, f2 := po.lookup(n2)
964 if !f1 || !f2 {
965 return false
966 }
967
968
969 if po.isnoneq(i1, i2) {
970 return true
971 }
972
973
974 if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
975 return true
976 }
977
978 return false
979 }
980
981
982
983
// setOrder records that n1<n2 (strict == true) or n1<=n2 (strict == false).
// It returns false if this contradicts what the poset already knows, and true
// otherwise (including when the relation was already known).
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither n1 nor n2 is in the poset, so they are not related to
		// any existing node. Create a new DAG to record the relation.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// n1 is in one of the DAGs, while n2 is not. Add n2 as a child
		// of n1.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// n2 is in one of the DAGs, while n1 is not. If n2 is a root,
		// n1 can simply become the new root above it.
		i1 = po.newnode(n1)

		if po.isroot(i2) {
			po.changeroot(i2, i1)
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Otherwise find n2's root, and add an extra node as the new
		// root with both the old root and n1 as non-strict children:
		//
		//      extra
		//      /   \
		//    root   \
		//    ...     n1
		//      \    /
		//        n2
		r := po.findroot(i2)

		extra := po.newnode(nil)
		po.changeroot(r, extra)
		po.upush(undoChangeRoot, extra, newedge(r, false))
		po.addchild(extra, r, false)
		po.addchild(extra, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// If the nodes are aliased, fail only if we're setting a strict
		// order (we cannot set n1<n2 if n1==n2).
		if i1 == i2 {
			return !strict
		}

		// If we are trying to record n1<=n2 but we know that n1!=n2,
		// record n1<n2, as it provides more information.
		if !strict && po.isnoneq(i1, i2) {
			strict = true
		}

		// Check if n1 somehow reaches n2.
		if po.reaches(i1, i2, false) {
			//      DAG          New      Action
			//      ---------------------------------------------------
			// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
			// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
			// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
			// #4:  N1<X<N2   |  N1<N2  | do nothing

			// Case #2.
			if strict && !po.reaches(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			// Cases #1, #3, #4: nothing to do.
			return true
		}

		// Check if n2 somehow reaches n1.
		if po.reaches(i2, i1, false) {
			//      DAG           New      Action
			//      ---------------------------------------------------
			// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
			// #6:  N2<=X<=N1  |  N1<N2  | contradiction
			// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
			// #8:  N2<X<N1    |  N1<N2  | contradiction

			if strict {
				// Cases #6 and #8.
				return false
			}

			// Case #5 or #7: collapsepath fails if it detects #7.
			return po.collapsepath(n2, n1)
		}

		// No known relation between n1 and n2. If they are in different
		// DAGs, merge the DAGs so that the new edge stays within one DAG.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			po.mergeroot(r1, r2)
		}

		// Connect n1 and n2.
		po.addchild(i1, i2, strict)
	}

	return true
}
1109
1110
1111
1112 func (po *poset) SetOrder(n1, n2 *Value) bool {
1113 if debugPoset {
1114 defer po.CheckIntegrity()
1115 }
1116 if n1.ID == n2.ID {
1117 panic("should not call SetOrder with n1==n2")
1118 }
1119 return po.setOrder(n1, n2, true)
1120 }
1121
1122
1123
1124 func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
1125 if debugPoset {
1126 defer po.CheckIntegrity()
1127 }
1128 if n1.ID == n2.ID {
1129 panic("should not call SetOrder with n1==n2")
1130 }
1131 return po.setOrder(n1, n2, false)
1132 }
1133
1134
1135
1136
1137 func (po *poset) SetEqual(n1, n2 *Value) bool {
1138 if debugPoset {
1139 defer po.CheckIntegrity()
1140 }
1141 if n1.ID == n2.ID {
1142 panic("should not call Add with n1==n2")
1143 }
1144
1145 i1, f1 := po.lookup(n1)
1146 i2, f2 := po.lookup(n2)
1147
1148 switch {
1149 case !f1 && !f2:
1150 i1 = po.newnode(n1)
1151 po.roots = append(po.roots, i1)
1152 po.upush(undoNewRoot, i1, 0)
1153 po.aliasnewnode(n1, n2)
1154 case f1 && !f2:
1155 po.aliasnewnode(n1, n2)
1156 case !f1 && f2:
1157 po.aliasnewnode(n2, n1)
1158 case f1 && f2:
1159 if i1 == i2 {
1160
1161 return true
1162 }
1163
1164
1165 if po.isnoneq(i1, i2) {
1166 return false
1167 }
1168
1169
1170
1171 if po.reaches(i1, i2, false) {
1172 return po.collapsepath(n1, n2)
1173 }
1174 if po.reaches(i2, i1, false) {
1175 return po.collapsepath(n2, n1)
1176 }
1177
1178 r1 := po.findroot(i1)
1179 r2 := po.findroot(i2)
1180 if r1 != r2 {
1181
1182 po.mergeroot(r1, r2)
1183 }
1184
1185
1186
1187 i2s := newBitset(int(po.lastidx) + 1)
1188 i2s.Set(i2)
1189 po.aliasnodes(n1, i2s)
1190 }
1191 return true
1192 }
1193
1194
1195
1196
1197 func (po *poset) SetNonEqual(n1, n2 *Value) bool {
1198 if debugPoset {
1199 defer po.CheckIntegrity()
1200 }
1201 if n1.ID == n2.ID {
1202 panic("should not call SetNonEqual with n1==n2")
1203 }
1204
1205
1206 i1, f1 := po.lookup(n1)
1207 i2, f2 := po.lookup(n2)
1208
1209
1210
1211 if !f1 || !f2 {
1212 po.setnoneq(n1, n2)
1213 return true
1214 }
1215
1216
1217 if po.isnoneq(i1, i2) {
1218 return true
1219 }
1220
1221
1222 if po.Equal(n1, n2) {
1223 return false
1224 }
1225
1226
1227 po.setnoneq(n1, n2)
1228
1229
1230
1231
1232
1233 if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
1234 po.addchild(i1, i2, true)
1235 }
1236 if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
1237 po.addchild(i2, i1, true)
1238 }
1239
1240 return true
1241 }
1242
1243
1244
1245
1246 func (po *poset) Checkpoint() {
1247 po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
1248 }
1249
1250
1251
1252
1253
// Undo restores the poset to its state at the previous Checkpoint. Passes
// are popped from the undo stack and reverted in LIFO order until an
// undoCheckpoint pass is consumed or the stack is exhausted. It panics if the
// undo stack is empty.
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}
	if debugPoset {
		defer po.CheckIntegrity()
	}

	for len(po.undo) > 0 {
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			return

		case undoSetChl:
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			// ID holds the larger node index (the noneq map key).
			po.noneq[uint32(pass.ID)].Clear(pass.idx)

		case undoNewNode:
			// Nodes are undone in LIFO order, so this must be the last one.
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

		case undoNewConstant:
			// Find the constant value currently mapped to pass.idx.
			// NOTE(review): this is a linear scan over po.constants.
			var val int64
			var i uint32
			for val, i = range po.constants {
				if i == pass.idx {
					break
				}
			}
			if i != pass.idx {
				panic("constant not found in undo pass")
			}
			if pass.ID == 0 {
				// The constant was new: forget it.
				delete(po.constants, val)
			} else {
				// The constant was remapped (by aliasnodes): restore the
				// previous node index, stored in ID.
				oldidx := uint32(pass.ID)
				po.constants[val] = oldidx
			}

		case undoAliasNode:
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// The Value had no previous mapping: remove it.
				delete(po.values, ID)
			} else {
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}

				// Restore the previous node index.
				po.values[ID] = prev
			}

		case undoNewRoot:
			// By now every edge added under this root has been undone.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// Put back the root that was replaced (the edge target).
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// The merged root's left child takes its place; the right
			// child becomes a separate root again.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			panic(pass.typ)
		}
	}

	// The whole undo stack was consumed: the poset must be empty again.
	if debugPoset && po.CheckEmpty() != nil {
		panic("poset not empty at the end of undo")
	}
}
1360
View as plain text