// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssa

import (
	"fmt"
)

// ----------------------------------------------------------------------------
// Sparse Conditional Constant Propagation
//
// Described in
// Mark N. Wegman, F. Kenneth Zadeck: Constant Propagation with Conditional Branches.
// TOPLAS 1991.
//
// This algorithm uses a three-level lattice for SSA values:
//
//	        Top        undefined
//	       / | \
//	   .. 1  2  3 ..   constant
//	       \ | /
//	       Bottom      not a constant
//
// It starts by optimistically assuming that all SSA values are initially Top
// and then propagates constant facts only along reachable control flow paths.
// Because some basic blocks have not been visited yet, the corresponding
// inputs of a phi become Top; we use meet(phi) to compute its lattice:
//
//	Top ∩ any = any
//	Bottom ∩ any = Bottom
//	ConstantA ∩ ConstantA = ConstantA
//	ConstantA ∩ ConstantB = Bottom
//
// Each lattice value is lowered at most twice (Top to Constant, Constant to
// Bottom) due to the lattice depth, so the algorithm converges quickly.
// In this way, sccp can discover optimization opportunities that cannot be
// found by combining constant folding, constant propagation and dead code
// elimination separately.

// Three-level lattice holds compile-time knowledge about an SSA value.
const (
	top      int8 = iota // undefined
	constant             // constant
	bottom               // not a constant
)

type lattice struct {
	tag int8   // lattice type
	val *Value // constant value
}

type worklist struct {
	f            *Func               // the target function to be optimized
	edges        []Edge              // propagate constant facts through edges
	uses         []*Value            // re-visiting set
	visited      map[Edge]bool       // visited edges
	latticeCells map[*Value]lattice  // constant lattices
	defUse       map[*Value][]*Value // def-use chains for some values
	defBlock     map[*Value][]*Block // use blocks of def
	visitedBlock []bool              // visited block
}
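// For illustration, here is how the meet rules above play out on a small,
// hypothetical input function (a sketch, not real compiler output):
//
//	func f(b bool) int {
//		x := 10    // x1: Constant(10)
//		if b {     // b: Bottom, so both successors are reachable
//			x = 10 // x2: Constant(10)
//		}
//		return x   // x3 = phi(x1, x2)
//	}
//
// Both incoming edges of the phi are visited and carry the same constant, so
// meet yields Constant(10) ∩ Constant(10) = Constant(10) and x3 is proven
// constant. Had the branch assigned x = 20 instead, the meet would be
// Constant(10) ∩ Constant(20) = Bottom.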
// sccp stands for sparse conditional constant propagation. It propagates
// constants through the CFG conditionally and applies constant folding,
// constant replacement and dead code elimination all together.
func sccp(f *Func) {
	var t worklist
	t.f = f
	t.edges = make([]Edge, 0)
	t.visited = make(map[Edge]bool)
	t.edges = append(t.edges, Edge{f.Entry, 0})
	t.defUse = make(map[*Value][]*Value)
	t.defBlock = make(map[*Value][]*Block)
	t.latticeCells = make(map[*Value]lattice)
	t.visitedBlock = f.Cache.allocBoolSlice(f.NumBlocks())
	defer f.Cache.freeBoolSlice(t.visitedBlock)

	// build it early since we rely heavily on the def-use chain later
	t.buildDefUses()

	// pick up either an edge or an SSA value from the worklist and process it
	for {
		if len(t.edges) > 0 {
			edge := t.edges[0]
			t.edges = t.edges[1:]
			if _, exist := t.visited[edge]; !exist {
				dest := edge.b
				destVisited := t.visitedBlock[dest.ID]

				// mark edge as visited
				t.visited[edge] = true
				t.visitedBlock[dest.ID] = true
				for _, val := range dest.Values {
					if val.Op == OpPhi || !destVisited {
						t.visitValue(val)
					}
				}
				// propagate constant facts through the CFG, taking the
				// condition test into account
				if !destVisited {
					t.propagate(dest)
				}
			}
			continue
		}
		if len(t.uses) > 0 {
			use := t.uses[0]
			t.uses = t.uses[1:]
			t.visitValue(use)
			continue
		}
		break
	}

	// apply optimizations based on discovered constants
	constCnt, rewireCnt := t.replaceConst()
	if f.pass.debug > 0 {
		if constCnt > 0 || rewireCnt > 0 {
			fmt.Printf("Phase SCCP for %v : %v constants, %v dce\n", f.Name, constCnt, rewireCnt)
		}
	}
}

func equals(a, b lattice) bool {
	if a == b {
		// fast path
		return true
	}
	if a.tag != b.tag {
		return false
	}
	if a.tag == constant {
		// Two different *Value nodes may hold the same constant, so compare
		// their op and auxInt instead of pointer identity.
		v1 := a.val
		v2 := b.val
		return v1.Op == v2.Op && v1.AuxInt == v2.AuxInt
	}
	return true
}
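// Comparing constant lattices structurally rather than by pointer identity
// matters in practice. A minimal sketch, assuming v1 and v2 are distinct
// *Value nodes that both represent OpConst64 [42]:
//
//	lt1 := lattice{constant, v1}
//	lt2 := lattice{constant, v2}
//	equals(lt1, lt2) // true: same Op and AuxInt, even though v1 != v2
//	lt1 == lt2       // false: struct equality compares the *Value pointers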
// possibleConst checks whether a Value can be folded to a constant. For
// Values that can never become constants (e.g. StaticCall), we don't make
// futile efforts.
func possibleConst(val *Value) bool {
	if isConst(val) {
		return true
	}
	switch val.Op {
	case OpCopy:
		return true
	case OpPhi:
		return true
	case
		// negate
		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
		OpCom8, OpCom16, OpCom32, OpCom64,
		// math
		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
		// conversion
		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
		OpCvtBoolToUint8,
		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// bit
		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
		// mask
		OpSlicemask,
		// safety check
		OpIsNonNil,
		// not
		OpNot:
		return true
	case
		// add
		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
		OpAdd32F, OpAdd64F,
		// sub
		OpSub64, OpSub32, OpSub16, OpSub8,
		OpSub32F, OpSub64F,
		// mul
		OpMul64, OpMul32, OpMul16, OpMul8,
		OpMul32F, OpMul64F,
		// div
		OpDiv32F, OpDiv64F,
		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u,
		// mod
		OpMod8, OpMod16, OpMod32, OpMod64,
		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
		// compare
		OpEq64, OpEq32, OpEq16, OpEq8,
		OpEq32F, OpEq64F,
		OpLess64, OpLess32, OpLess16, OpLess8,
		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
		OpLess32F, OpLess64F,
		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
		OpLeq32F, OpLeq64F,
		OpEqB, OpNeqB,
		// shift
		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
		// safety check
		OpIsInBounds, OpIsSliceInBounds,
		// bit
		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
		OpOr8, OpOr16, OpOr32, OpOr64,
		OpXor8, OpXor16, OpXor32, OpXor64:
		return true
	default:
		return false
	}
}

func (t *worklist) getLatticeCell(val *Value) lattice {
	if !possibleConst(val) {
		// values that can never become constants are always Bottom
		return lattice{bottom, nil}
	}
	lt, exist := t.latticeCells[val]
	if !exist {
		return lattice{top, nil} // optimistically Top for an unvisited value
	}
	return lt
}

func isConst(val *Value) bool {
	switch val.Op {
	case OpConst64, OpConst32, OpConst16, OpConst8,
		OpConstBool, OpConst32F, OpConst64F:
		return true
	default:
		return false
	}
}
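// The optimistic Top default in getLatticeCell is what lets sccp find
// constants through loops. A sketch on hypothetical SSA (not real compiler
// output):
//
//	b2: v2 = Phi(v1, v3) // v1 = Const64 [1] from the loop preheader
//	    v3 = Copy v2     // the loop body never changes v2
//
// When the phi is first visited, the back edge carrying v3 has not been
// visited, so v3 contributes Top and meet yields Constant(1). Visiting v3
// then copies Constant(1), and re-visiting the phi gives
// Constant(1) ∩ Constant(1) = Constant(1). A pessimistic Bottom default
// would instead lock both values at Bottom forever.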
// buildDefUses builds def-use chains for some values early, because once the
// lattice of a value changes, we need to update the lattices of its uses. We
// don't need all uses, though: only uses that can become constants are added
// to the re-visit worklist, since no matter how many times the other uses are
// revisited, their lattice remains Bottom.
func (t *worklist) buildDefUses() {
	for _, block := range t.f.Blocks {
		for _, val := range block.Values {
			for _, arg := range val.Args {
				// find its uses; only uses that can become constants are
				// taken into account
				if possibleConst(arg) && possibleConst(val) {
					if _, exist := t.defUse[arg]; !exist {
						t.defUse[arg] = make([]*Value, 0, arg.Uses)
					}
					t.defUse[arg] = append(t.defUse[arg], val)
				}
			}
		}
		for _, ctl := range block.ControlValues() {
			// for control values that can become constants, find their use blocks
			if possibleConst(ctl) {
				t.defBlock[ctl] = append(t.defBlock[ctl], block)
			}
		}
	}
}

// addUses finds all uses of a value and appends them to the worklist for
// further processing.
func (t *worklist) addUses(val *Value) {
	for _, use := range t.defUse[val] {
		if val == use {
			// Phi may refer to itself as a use; ignore it to avoid
			// re-visiting the phi, for performance reasons.
			continue
		}
		t.uses = append(t.uses, use)
	}
	for _, block := range t.defBlock[val] {
		if t.visitedBlock[block.ID] {
			t.propagate(block)
		}
	}
}

// meet meets all phi arguments and computes the resulting lattice.
func (t *worklist) meet(val *Value) lattice {
	optimisticLt := lattice{top, nil}
	for i := 0; i < len(val.Args); i++ {
		edge := Edge{val.Block, i}
		// If the incoming edge for the phi has not been visited, assume Top
		// optimistically. According to the rules of meet:
		//
		//	Top ∩ any = any
		//
		// Top participates in meet() but does not affect the result, so here
		// we ignore Top and only take the other lattices into consideration.
		if _, exist := t.visited[edge]; exist {
			lt := t.getLatticeCell(val.Args[i])
			if lt.tag == constant {
				if optimisticLt.tag == top {
					optimisticLt = lt
				} else {
					if !equals(optimisticLt, lt) {
						// ConstantA ∩ ConstantB = Bottom
						return lattice{bottom, nil}
					}
				}
			} else if lt.tag == bottom {
				// Bottom ∩ any = Bottom
				return lattice{bottom, nil}
			} else {
				// Top ∩ any = any
			}
		} else {
			// Top ∩ any = any
		}
	}
	// ConstantA ∩ ConstantA = ConstantA or Top ∩ any = any
	return optimisticLt
}
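// A sketch of how the two worklists interact, on hypothetical SSA (not real
// compiler output):
//
//	v1 = Const64 [2]
//	v2 = Phi(v1, v4)  // meet lowers v2: Top -> Constant(2)
//	v3 = Mul64 v2 v1  // v3 is in defUse[v2]
//
// When v2's lattice changes, addUses pushes v3 onto t.uses; re-visiting v3
// folds Mul64 of two constants to Constant(4) via computeLattice below, which
// in turn re-queues v3's own uses. Because each cell lowers at most twice,
// every value enters t.uses only a bounded number of times.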
func computeLattice(f *Func, val *Value, args ...*Value) lattice {
	// In general, we need to perform constant evaluation based on constant
	// args:
	//
	//	res := lattice{constant, nil}
	//	switch op {
	//	case OpAdd16:
	//		res.val = newConst(argLt1.val.AuxInt16() + argLt2.val.AuxInt16())
	//	case OpAdd32:
	//		res.val = newConst(argLt1.val.AuxInt32() + argLt2.val.AuxInt32())
	//	case OpDiv8:
	//		if !isDivideByZero(argLt2.val.AuxInt8()) {
	//			res.val = newConst(argLt1.val.AuxInt8() / argLt2.val.AuxInt8())
	//		}
	//	...
	//	}
	//
	// However, this would create a huge switch for all opcodes that can be
	// evaluated at compile time. Moreover, some operations can be evaluated
	// only if their arguments satisfy additional conditions (e.g. divide by
	// zero), which is fragile and error prone. Instead, we take a shortcut
	// and reuse the existing generic rewrite rules for compile-time
	// evaluation. But generic rules rewrite the original value in place,
	// which is undesired here: the lattice of a value may change multiple
	// times, and once the value is rewritten we lose the opportunity to
	// change it again, which can lead to errors. For example, we cannot
	// rewrite a value immediately after visiting a Phi, because some of its
	// input edges may not have been visited yet. So we evaluate on a
	// temporary value instead.
	constValue := f.newValue(val.Op, val.Type, f.Entry, val.Pos)
	constValue.AddArgs(args...)
	matched := rewriteValuegeneric(constValue)
	if matched {
		if isConst(constValue) {
			return lattice{constant, constValue}
		}
	}
	// Either we cannot match a generic rule for the given value or it does
	// not satisfy additional constraints (e.g. divide by zero). In these
	// cases, clean up the temporary value immediately in case it is not
	// dominated by its args.
	constValue.reset(OpInvalid)
	return lattice{bottom, nil}
}

func (t *worklist) visitValue(val *Value) {
	if !possibleConst(val) {
		// fast fail for values that can never become constants: no lowering
		// happens on them, so their lattice is Bottom from the start.
		return
	}

	oldLt := t.getLatticeCell(val)
	defer func() {
		// re-visit all uses of the value if its lattice changed
		newLt := t.getLatticeCell(val)
		if !equals(newLt, oldLt) {
			if int8(oldLt.tag) > int8(newLt.tag) {
				t.f.Fatalf("Must lower lattice\n")
			}
			t.addUses(val)
		}
	}()

	switch val.Op {
	// they are constant values, aren't they?
	case OpConst64, OpConst32, OpConst16, OpConst8,
		OpConstBool, OpConst32F, OpConst64F: //TODO: support ConstNil ConstString etc
		t.latticeCells[val] = lattice{constant, val}
	// lattice value of copy(x) actually means lattice value of (x)
	case OpCopy:
		t.latticeCells[val] = t.getLatticeCell(val.Args[0])
	// phi should be processed specially
	case OpPhi:
		t.latticeCells[val] = t.meet(val)
	// fold 1-input operations
	case
		// negate
		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
		OpCom8, OpCom16, OpCom32, OpCom64,
		// math
		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
		// conversion
		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
		OpCvtBoolToUint8,
		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// bit
		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
		// mask
		OpSlicemask,
		// safety check
		OpIsNonNil,
		// not
		OpNot:
		lt1 := t.getLatticeCell(val.Args[0])

		if lt1.tag == constant {
			// here we take a shortcut by reusing generic rules to fold constants
			t.latticeCells[val] = computeLattice(t.f, val, lt1.val)
		} else {
			t.latticeCells[val] = lattice{lt1.tag, nil}
		}
	// fold 2-input operations
	case
		// add
		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
		OpAdd32F, OpAdd64F,
		// sub
		OpSub64, OpSub32, OpSub16, OpSub8,
		OpSub32F, OpSub64F,
		// mul
		OpMul64, OpMul32, OpMul16, OpMul8,
		OpMul32F, OpMul64F,
		// div
		OpDiv32F, OpDiv64F,
		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u, //TODO: support div128u
		// mod
		OpMod8, OpMod16, OpMod32, OpMod64,
		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
		// compare
		OpEq64, OpEq32, OpEq16, OpEq8,
		OpEq32F, OpEq64F,
		OpLess64, OpLess32, OpLess16, OpLess8,
		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
		OpLess32F, OpLess64F,
		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
		OpLeq32F, OpLeq64F,
		OpEqB, OpNeqB,
		// shift
		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
		// safety check
		OpIsInBounds, OpIsSliceInBounds,
		// bit
		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
		OpOr8, OpOr16, OpOr32, OpOr64,
		OpXor8, OpXor16, OpXor32, OpXor64:
		lt1 := t.getLatticeCell(val.Args[0])
		lt2 := t.getLatticeCell(val.Args[1])

		if lt1.tag == constant && lt2.tag == constant {
			// here we take a shortcut by reusing generic rules to fold constants
			t.latticeCells[val] = computeLattice(t.f, val, lt1.val, lt2.val)
		} else {
			if lt1.tag == bottom || lt2.tag == bottom {
				t.latticeCells[val] = lattice{bottom, nil}
			} else {
				t.latticeCells[val] = lattice{top, nil}
			}
		}
	default:
		// Any other kind of value cannot be a constant; its lattice is
		// always Bottom.
	}
}
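// As an example of the conditional propagation performed by propagate below,
// consider a hypothetical block ending in an if whose condition has been
// proven constant:
//
//	b1: ...
//	    v5 = Less64 v3 v4 // lattice: Constant, val = ConstBool [1] (true)
//	If v5 -> b2 b3
//
// For BlockIf, Succs[0] is the branch taken when the condition is true, so
// branchIdx = 1 - AuxInt = 0 and only the edge to b2 is appended to the edge
// worklist; b3 stays unvisited unless some other path reaches it.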
// propagate propagates constant facts through the CFG. If the block has a
// single successor, add the successor unconditionally. If the block has
// multiple successors, only add the branch destination corresponding to the
// lattice value of the condition value.
func (t *worklist) propagate(block *Block) {
	switch block.Kind {
	case BlockExit, BlockRet, BlockRetJmp, BlockInvalid:
		// control flow ends, do nothing then
	case BlockDefer:
		// we know nothing about control flow, add all branch destinations
		t.edges = append(t.edges, block.Succs...)
	case BlockFirst:
		fallthrough // always takes the first branch
	case BlockPlain:
		t.edges = append(t.edges, block.Succs[0])
	case BlockIf, BlockJumpTable:
		cond := block.ControlValues()[0]
		condLattice := t.getLatticeCell(cond)
		if condLattice.tag == bottom {
			// we know nothing about control flow, add all branch destinations
			t.edges = append(t.edges, block.Succs...)
		} else if condLattice.tag == constant {
			// add the branch destination depending on the condition value
			var branchIdx int64
			if block.Kind == BlockIf {
				branchIdx = 1 - condLattice.val.AuxInt
			} else {
				branchIdx = condLattice.val.AuxInt
			}
			t.edges = append(t.edges, block.Succs[branchIdx])
		} else {
			// condition value is not visited yet, don't propagate it now
		}
	default:
		t.f.Fatalf("All kinds of block should be processed above.")
	}
}

// rewireSuccessor rewires corresponding successors according to the constant
// value discovered by the previous analysis. As a result, some successors
// become unreachable and can be removed by the subsequent deadcode pass.
func rewireSuccessor(block *Block, constVal *Value) bool {
	switch block.Kind {
	case BlockIf:
		block.removeEdge(int(constVal.AuxInt))
		block.Kind = BlockPlain
		block.Likely = BranchUnknown
		block.ResetControls()
		return true
	case BlockJumpTable:
		// Remove everything but the known taken branch.
		idx := int(constVal.AuxInt)
		if idx < 0 || idx >= len(block.Succs) {
			// This can only happen in unreachable code,
			// as an invariant of jump tables is that their
			// input index is in range.
			// See issue 64826.
			return false
		}
		block.swapSuccessorsByIdx(0, idx)
		for len(block.Succs) > 1 {
			block.removeEdge(1)
		}
		block.Kind = BlockPlain
		block.Likely = BranchUnknown
		block.ResetControls()
		return true
	default:
		return false
	}
}

// replaceConst replaces values that sccp has proven to be constants with the
// corresponding constant values, and rewires successors of blocks controlled
// by such constants.
func (t *worklist) replaceConst() (int, int) {
	constCnt, rewireCnt := 0, 0
	for val, lt := range t.latticeCells {
		if lt.tag == constant {
			if !isConst(val) {
				if t.f.pass.debug > 0 {
					fmt.Printf("Replace %v with %v\n", val.LongString(), lt.val.LongString())
				}
				val.reset(lt.val.Op)
				val.AuxInt = lt.val.AuxInt
				constCnt++
			}
			// If the constant value controls this block, rewire its
			// successors according to the value.
			ctrlBlock := t.defBlock[val]
			for _, block := range ctrlBlock {
				if rewireSuccessor(block, lt.val) {
					rewireCnt++
					if t.f.pass.debug > 0 {
						fmt.Printf("Rewire %v %v successors\n", block.Kind, block)
					}
				}
			}
		}
	}
	return constCnt, rewireCnt
}
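// Putting it all together, a sketch of sccp's overall effect on a function
// like the following (hypothetical source, not a test case from this file):
//
//	func f() int {
//		x := 1
//		if x == 1 {
//			return 10
//		}
//		return 20
//	}
//
// visitValue folds the comparison to Constant(true), replaceConst rewrites
// the comparison value into a ConstBool, and rewireSuccessor turns the
// BlockIf into a BlockPlain that reaches only the "return 10" block. The now
// unreachable "return 20" block is left for the later deadcode pass to
// delete.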