63 class NullableVector {
77 assert(
valid &&
"Accessing invalid vector!");
89 static int newVarCount = 0;
99 auto newClause = mk<Clause>();
100 newClause->setSrcLoc(clause->getSrcLoc());
101 auto clauseHead = mk<Atom>(clause->getHead()->getQualifiedName());
104 for (
Literal* lit : clause->getBodyLiterals()) {
109 for (Argument* arg : clause->getHead()->getArguments()) {
110 if (
auto* constant =
dynamic_cast<Constant*
>(arg)) {
112 std::stringstream newVar;
113 newVar <<
"<new_var_" << newVarCount++ <<
">";
114 clauseHead->addArgument(mk<ast::Variable>(newVar.str()));
125 newClause->setHead(std::move(clauseHead));
141 struct M :
public NodeMapper {
142 mutable bool changed =
false;
143 const std::set<QualifiedName> inlinedRelations;
144 bool replaceUnderscores;
146 M(std::set<QualifiedName> inlinedRelations,
bool replaceUnderscores)
147 : inlinedRelations(
std::move(inlinedRelations)), replaceUnderscores(replaceUnderscores) {}
150 static int underscoreCount = 0;
152 if (!replaceUnderscores) {
154 if (
auto* atom =
dynamic_cast<Atom*
>(node.get())) {
155 if (inlinedRelations.find(atom->getQualifiedName()) != inlinedRelations.end()) {
158 M replace(inlinedRelations,
true);
159 node->apply(replace);
160 changed |= replace.changed;
164 }
else if (isA<UnnamedVariable>(node.get())) {
168 std::stringstream newVarName;
169 newVarName <<
"<underscore_" << underscoreCount++ <<
">";
171 return mk<ast::Variable>(newVarName.str());
180 std::set<QualifiedName> inlinedRelations;
181 for (Relation*
rel : program.getRelations()) {
183 inlinedRelations.insert(
rel->getQualifiedName());
188 M update(inlinedRelations,
false);
189 program.apply(update);
190 return update.changed;
197 bool foundInlinedAtom =
false;
202 foundInlinedAtom =
true;
206 return foundInlinedAtom;
221 for (
size_t i = 0;
i <
sub.size();
i++) {
222 auto currPair =
sub[
i];
232 }
else if (isA<Constant>(
lhs) && isA<Constant>(
rhs)) {
236 }
else if (isA<RecordInit>(
lhs) && isA<RecordInit>(
rhs)) {
240 std::vector<Argument*> lhsArgs =
static_cast<RecordInit*
>(
lhs)->getArguments();
241 std::vector<Argument*> rhsArgs =
static_cast<RecordInit*
>(
rhs)->getArguments();
243 if (lhsArgs.size() != rhsArgs.size()) {
249 for (
size_t i = 0;
i < lhsArgs.size();
i++) {
250 sub.push_back(std::make_pair(lhsArgs[
i], rhsArgs[
i]));
256 }
else if ((isA<RecordInit>(
lhs) && isA<Constant>(
rhs)) ||
257 (isA<Constant>(
lhs) && isA<RecordInit>(
rhs))) {
272 NullableVector<std::pair<Argument*, Argument*>>
unifyAtoms(Atom* first, Atom* second) {
273 std::vector<std::pair<Argument*, Argument*>> substitution;
275 std::vector<Argument*> firstArgs = first->getArguments();
276 std::vector<Argument*> secondArgs = second->getArguments();
279 for (
size_t i = 0;
i < firstArgs.size();
i++) {
280 substitution.push_back(std::make_pair(firstArgs[
i], secondArgs[
i]));
286 return NullableVector<std::pair<Argument*, Argument*>>(substitution);
289 return NullableVector<std::pair<Argument*, Argument*>>();
298 std::pair<NullableVector<Literal*>, std::vector<BinaryConstraint*>>
inlineBodyLiterals(
299 Atom* atom, Clause* atomInlineClause) {
300 bool changed =
false;
301 std::vector<Literal*> addedLits;
302 std::vector<BinaryConstraint*> constraints;
306 static int inlineCount = 0;
313 VariableRenamer(
int varnum) : varnum(varnum) {}
318 std::stringstream newName;
319 newName <<
"<inlined_" << var->getName() <<
"_" << varnum <<
">";
320 newVar->setName(newName.str());
328 VariableRenamer update(inlineCount);
329 atomClause->apply(update);
334 NullableVector<std::pair<Argument*, Argument*>> res =
unifyAtoms(atomClause->getHead(), atom);
337 for (std::pair<Argument*, Argument*> pair : res.getVector()) {
339 constraints.push_back(
new BinaryConstraint(
344 for (Literal* lit : atomClause->getBodyLiterals()) {
345 addedLits.push_back(lit->clone());
350 return std::make_pair(NullableVector<Literal*>(addedLits), constraints);
352 return std::make_pair(NullableVector<Literal*>(), constraints);
360 if (
auto* atom =
dynamic_cast<Atom*
>(lit)) {
363 }
else if (
auto* neg =
dynamic_cast<Negation*
>(lit)) {
364 Atom* atom = neg->getAtom()->clone();
366 }
else if (
auto* cons =
dynamic_cast<Constraint*
>(lit)) {
372 fatal(
"unsupported literal type: %s", *lit);
380 std::vector<std::vector<Literal*>> negation;
383 if (litGroups.empty()) {
387 std::vector<Literal*> litGroup = litGroups[0];
388 if (litGroups.size() == 1) {
390 for (
auto&
i : litGroup) {
391 std::vector<Literal*> newVec;
393 negation.push_back(newVec);
402 std::vector<std::vector<Literal*>>(litGroups.begin() + 1, litGroups.end()));
406 for (Literal* lhsLit : litGroup) {
407 for (std::vector<Literal*> rhsVec : combinedRHS) {
408 std::vector<Literal*> newVec;
411 for (Literal* lit : rhsVec) {
412 newVec.push_back(lit->clone());
415 negation.push_back(newVec);
419 for (std::vector<Literal*> rhsVec : combinedRHS) {
420 for (Literal* lit : rhsVec) {
437 std::vector<std::vector<Literal*>> addedBodyLiterals;
438 std::vector<std::vector<BinaryConstraint*>> addedConstraints;
443 std::pair<NullableVector<Literal*>, std::vector<BinaryConstraint*>> inlineResult =
446 std::vector<BinaryConstraint*> currConstraints = inlineResult.second;
448 if (!replacementBodyLiterals.isValid()) {
453 addedBodyLiterals.push_back(replacementBodyLiterals.getVector());
454 addedConstraints.push_back(currConstraints);
462 for (
auto& negatedAddedBodyLiteral : negatedAddedBodyLiterals) {
463 for (std::vector<BinaryConstraint*> constraintGroup : addedConstraints) {
465 negatedAddedBodyLiteral.push_back(constraint->clone());
471 for (std::vector<Literal*> litGroup : addedBodyLiterals) {
472 for (Literal* lit : litGroup) {
476 for (std::vector<BinaryConstraint*> consGroup : addedConstraints) {
477 for (Constraint* cons : consGroup) {
482 return negatedAddedBodyLiterals;
489 static int varCount = 0;
492 struct M :
public NodeMapper {
494 M(
int varnum) : varnum(varnum) {}
498 std::stringstream newName;
499 newName << var->getName() <<
"-v" << varnum;
500 newVar->setName(newName.str());
520 if (aggrs.size() == 1) {
538 bool changed =
false;
539 std::vector<Argument*> versions;
543 if (
const auto* aggr =
dynamic_cast<const Aggregator*
>(arg)) {
545 if (aggr->getTargetExpression() !=
nullptr) {
549 if (argumentVersions.isValid()) {
554 for (
Argument* newArg : argumentVersions.getVector()) {
557 for (
Literal* lit : aggr->getBodyLiterals()) {
560 newAggr->setBody(std::move(newBody));
561 versions.push_back(newAggr);
569 std::vector<Literal*> bodyLiterals = aggr->getBodyLiterals();
570 for (
size_t i = 0;
i < bodyLiterals.size();
i++) {
575 if (literalVersions.isValid()) {
582 std::vector<Aggregator*> aggrVersions;
583 for (std::vector<Literal*> inlineVersions : literalVersions.getVector()) {
585 if (aggr->getTargetExpression() !=
nullptr) {
588 auto* newAggr =
new Aggregator(aggr->getBaseOperator(), std::move(target));
592 for (
size_t j = 0;
j < bodyLiterals.size();
j++) {
599 for (
Literal* addedLit : inlineVersions) {
603 newAggr->setBody(std::move(newBody));
604 aggrVersions.push_back(newAggr);
638 }
else if (
const auto* functor =
dynamic_cast<const Functor*
>(arg)) {
640 for (
auto funArg : functor->getArguments()) {
644 if (argumentVersions.isValid()) {
646 for (
Argument* newArgVersion : argumentVersions.getVector()) {
650 for (
auto& functorArg : functor->getArguments()) {
652 argsCopy.emplace_back(newArgVersion);
654 argsCopy.emplace_back(functorArg->clone());
658 if (
const auto* intrFunc =
dynamic_cast<const IntrinsicFunctor*
>(arg)) {
660 new IntrinsicFunctor(intrFunc->getBaseFunctionOp(), std::move(argsCopy));
661 newFunctor->setSrcLoc(functor->getSrcLoc());
662 versions.push_back(newFunctor);
663 }
else if (
const auto* userFunc =
dynamic_cast<const UserDefinedFunctor*
>(arg)) {
664 auto* newFunctor =
new UserDefinedFunctor(userFunc->getName(), std::move(argsCopy));
665 newFunctor->setSrcLoc(userFunc->getSrcLoc());
666 versions.push_back(newFunctor);
674 }
else if (
const auto* cast =
dynamic_cast<const ast::TypeCast*
>(arg)) {
675 NullableVector<Argument*> argumentVersions =
getInlinedArgument(program, cast->getValue());
676 if (argumentVersions.isValid()) {
678 for (Argument* newArg : argumentVersions.getVector()) {
679 Argument* newTypeCast =
new ast::TypeCast(Own<Argument>(newArg), cast->getType());
680 versions.push_back(newTypeCast);
683 }
else if (
const auto* record =
dynamic_cast<const RecordInit*
>(arg)) {
684 std::vector<Argument*> recordArguments = record->getArguments();
685 for (
size_t i = 0;
i < recordArguments.size();
i++) {
686 Argument* currentRecArg = recordArguments[
i];
687 NullableVector<Argument*> argumentVersions =
getInlinedArgument(program, currentRecArg);
688 if (argumentVersions.isValid()) {
690 for (Argument* newArgumentVersion : argumentVersions.getVector()) {
691 auto* newRecordArg =
new RecordInit();
692 for (
size_t j = 0;
j < recordArguments.size();
j++) {
694 newRecordArg->addArgument(Own<Argument>(newArgumentVersion));
699 versions.push_back(newRecordArg);
712 return NullableVector<Argument*>(versions);
715 return NullableVector<Argument*>();
723 NullableVector<Atom*>
getInlinedAtom(Program& program, Atom& atom) {
724 bool changed =
false;
725 std::vector<Atom*> versions;
728 std::vector<Argument*> arguments = atom.getArguments();
729 for (
size_t i = 0;
i < arguments.size();
i++) {
734 if (argumentVersions.isValid()) {
739 for (
Argument* newArgument : argumentVersions.getVector()) {
740 auto args = atom.getArguments();
742 for (
size_t j = 0;
j < args.size();
j++) {
744 newArgs.emplace_back(newArgument);
746 newArgs.emplace_back(args[
j]->
clone());
749 auto* newAtom =
new Atom(atom.getQualifiedName(), std::move(newArgs), atom.getSrcLoc());
750 versions.push_back(newAtom);
762 return NullableVector<Atom*>(versions);
765 return NullableVector<Atom*>();
779 NullableVector<std::vector<Literal*>>
getInlinedLiteral(Program& program, Literal* lit) {
780 bool inlined =
false;
781 bool changed =
false;
783 std::vector<std::vector<Literal*>> addedBodyLiterals;
784 std::vector<Literal*> versions;
786 if (
auto* atom =
dynamic_cast<Atom*
>(lit)) {
799 std::pair<NullableVector<Literal*>, std::vector<BinaryConstraint*>> inlineResult =
802 std::vector<BinaryConstraint*> currConstraints = inlineResult.second;
804 if (!replacementBodyLiterals.isValid()) {
811 std::vector<Literal*> bodyResult = replacementBodyLiterals.getVector();
814 bodyResult.push_back(cons);
817 addedBodyLiterals.push_back(bodyResult);
822 if (atomVersions.isValid()) {
825 for (
Atom* newAtom : atomVersions.getVector()) {
826 versions.push_back(newAtom);
830 }
else if (
auto neg =
dynamic_cast<Negation*
>(lit)) {
832 Atom* atom = neg->getAtom();
833 NullableVector<std::vector<Literal*>> atomVersions =
getInlinedLiteral(program, atom);
835 if (atomVersions.isValid()) {
839 if (atomVersions.getVector().empty()) {
854 if (atomVersions.isValid()) {
855 for (
const auto& curVec : atomVersions.getVector()) {
856 for (
auto* cur : curVec) {
861 }
else if (
auto* constraint =
dynamic_cast<BinaryConstraint*
>(lit)) {
862 NullableVector<Argument*> lhsVersions =
getInlinedArgument(program, constraint->getLHS());
863 if (lhsVersions.isValid()) {
865 for (Argument* newLhs : lhsVersions.getVector()) {
866 Literal* newLit =
new BinaryConstraint(constraint->getBaseOperator(), Own<Argument>(newLhs),
868 versions.push_back(newLit);
871 NullableVector<Argument*> rhsVersions =
getInlinedArgument(program, constraint->getRHS());
872 if (rhsVersions.isValid()) {
874 for (Argument* newRhs : rhsVersions.getVector()) {
875 Literal* newLit =
new BinaryConstraint(constraint->getBaseOperator(),
877 versions.push_back(newLit);
886 for (Literal* version : versions) {
887 std::vector<Literal*> newBody;
888 newBody.push_back(version);
889 addedBodyLiterals.push_back(newBody);
895 return NullableVector<std::vector<Literal*>>(addedBodyLiterals);
897 return NullableVector<std::vector<Literal*>>();
905 std::vector<Clause*>
getInlinedClause(Program& program,
const Clause& clause) {
906 bool changed =
false;
907 std::vector<Clause*> versions;
913 if (headVersions.isValid()) {
918 for (
Atom* newHead : headVersions.getVector()) {
919 auto* newClause =
new Clause();
920 newClause->setSrcLoc(clause.getSrcLoc());
925 for (
Literal* lit : clause.getBodyLiterals()) {
929 versions.push_back(newClause);
936 std::vector<Literal*> bodyLiterals = clause.getBodyLiterals();
937 for (
size_t i = 0;
i < bodyLiterals.size();
i++) {
938 Literal* currLit = bodyLiterals[
i];
947 NullableVector<std::vector<Literal*>> litVersions =
getInlinedLiteral(program, currLit);
949 if (litVersions.isValid()) {
955 std::vector<std::vector<Literal*>> bodyVersions = litVersions.getVector();
958 auto baseClause = Own<Clause>(
cloneHead(&clause));
959 for (Literal* oldLit : bodyLiterals) {
960 if (currLit != oldLit) {
965 for (std::vector<Literal*> body : bodyVersions) {
966 Clause* replacementClause = baseClause->clone();
970 for (Literal* newLit : body) {
971 replacementClause->addToBody(Own<Literal>(newLit));
974 versions.push_back(replacementClause);
987 std::vector<Clause*> ret;
988 ret.push_back(clause.clone());
997 bool changed =
false;
998 Program& program = translationUnit.getProgram();
1010 bool clausesChanged =
true;
1011 while (clausesChanged) {
1012 std::set<Clause*> clausesToDelete;
1013 clausesChanged =
false;
1029 clausesToDelete.insert(clause);
1030 for (
Clause* replacementClause : newClauses) {
1031 program.addClause(
Own<Clause>(replacementClause));
1035 clausesChanged =
true;
1042 for (
const Clause* clause : clausesToDelete) {
1043 program.removeClause(clause);