100 using namespace analysis;
102 struct SemanticCheckerImpl {
104 SemanticCheckerImpl(TranslationUnit& tu);
107 const IOTypeAnalysis& ioTypes = *tu.getAnalysis<IOTypeAnalysis>();
117 void checkAtom(
const Atom& atom);
118 void checkLiteral(
const Literal& literal);
122 void checkConstant(
const Argument& argument);
123 void checkFact(
const Clause& fact);
124 void checkClause(
const Clause& clause);
125 void checkComplexRule(std::set<const Clause*> multiRule);
130 void checkBranchInits();
132 void checkNamespaces();
134 void checkWitnessProblem();
135 void checkInlining();
146 std::vector<std::string> suppressedRelations =
149 if (std::find(suppressedRelations.begin(), suppressedRelations.end(),
"*") !=
150 suppressedRelations.end()) {
157 for (
auto& relname : suppressedRelations) {
158 const std::vector<std::string> comps =
splitString(relname,
'.');
159 if (!comps.empty()) {
162 for (
size_t i = 1;
i < comps.size();
i++) {
163 relid.append(comps[
i]);
188 std::map<SrcLocation, std::set<const Clause*>> multiRuleMap;
191 multiRuleMap[clause->getSrcLoc()].insert(clause);
195 for (
const auto& multiRule : multiRuleMap) {
205 GroundedTermsChecker().verify(
tu);
208 TypeChecker().verify(
tu);
217 const Literal* foundLiteral =
nullptr;
222 std::set<const Relation*, NameComparison> sortedRelSet(relSet.begin(), relSet.end());
224 std::string relationsListStr =
toString(
join(sortedRelSet,
",",
225 [](std::ostream& out,
const Relation* r) { out << r->getQualifiedName(); }));
226 std::vector<DiagnosticMessage> messages;
227 messages.push_back(DiagnosticMessage(
228 "Relation " +
toString(cur->getQualifiedName()), cur->getSrcLoc()));
229 std::string negOrAgg = hasNegation ?
"negation" :
"aggregation";
231 DiagnosticMessage(
"has cyclic " + negOrAgg, foundLiteral->getSrcLoc()));
233 DiagnosticMessage(
"Unable to stratify relation(s) {" + relationsListStr +
"}"),
250 if (r->getArity() != atom.getArity()) {
252 "Mismatching arity of relation " +
toString(atom.getQualifiedName()), atom.getSrcLoc());
255 for (
const Argument* arg : atom.getArguments()) {
262 auto* type = sumTypesBranches.getType(adt.getConstructor());
263 if (type == nullptr) {
264 report.addError(
"Undeclared branch", adt.getSrcLoc());
268 size_t declaredArity =
269 as<analysis::AlgebraicDataType>(
type)->getBranchTypes(adt.getConstructor()).size();
270 size_t branchArity = adt.getArguments().size();
271 if (declaredArity != branchArity) {
285 std::set<const UnnamedVariable*> getUnnamedVariables(
const Node& node) {
286 std::set<const UnnamedVariable*> unnamedInAggregates;
288 visitDepthFirst(agg, [&](
const UnnamedVariable& var) { unnamedInAggregates.insert(&var); });
291 std::set<const UnnamedVariable*> unnamed;
293 if (!
contains(unnamedInAggregates, &var)) {
294 unnamed.insert(&var);
303 void SemanticCheckerImpl::checkLiteral(
const Literal& literal) {
305 if (
const auto* atom = as<Atom>(literal)) {
309 if (
const auto* neg = as<Negation>(literal)) {
310 checkAtom(*neg->getAtom());
313 if (
const auto* constraint = as<BinaryConstraint>(literal)) {
314 checkArgument(*constraint->getLHS());
315 checkArgument(*constraint->getRHS());
317 std::set<const UnnamedVariable*> unnamedInRecord;
320 if (auto* unnamed = as<UnnamedVariable>(arg)) {
321 unnamedInRecord.insert(unnamed);
327 if (isA<Aggregator>(*constraint->getLHS()) || isA<Aggregator>(*constraint->getRHS())) {
331 for (
auto* unnamed : getUnnamedVariables(*constraint)) {
332 if (!
contains(unnamedInRecord, unnamed)) {
333 report.
addError(
"Underscore in binary relation", unnamed->getSrcLoc());
344 bool SemanticCheckerImpl::isDependent(
const Clause& agg1,
const Clause& agg2) {
347 bool dependent =
false;
354 if (var == searchVar) {
355 matchingVarPtr = &var;
360 if (matchingVarPtr !=
nullptr) {
361 if (!groundedInAgg1[&searchVar] && groundedInAgg2[matchingVarPtr]) {
369 void SemanticCheckerImpl::checkAggregator(
const Aggregator& aggregator) {
370 auto& report = tu.getErrorReport();
371 const Program& program = tu.getProgram();
372 Clause dummyClauseAggregator;
376 if (candidateAggregate != aggregator) {
391 if (isDependent(dummyClauseAggregator, dummyClauseOther) &&
392 isDependent(dummyClauseOther, dummyClauseAggregator)) {
393 report.
addError(
"Mutually dependent aggregate", aggregator.getSrcLoc());
398 for (Literal* literal : aggregator.getBodyLiterals()) {
399 checkLiteral(*literal);
403 void SemanticCheckerImpl::checkArgument(
const Argument& arg) {
404 if (
const auto* agg =
dynamic_cast<const Aggregator*
>(&arg)) {
405 checkAggregator(*agg);
406 }
else if (
const auto* func =
dynamic_cast<const Functor*
>(&arg)) {
407 for (
auto arg : func->getArguments()) {
419 bool isConstantArgument(
const Argument* arg) {
420 assert(arg !=
nullptr);
422 if (isA<ast::Variable>(arg) || isA<UnnamedVariable>(arg)) {
424 }
else if (isA<UserDefinedFunctor>(arg)) {
426 }
else if (isA<Counter>(arg)) {
428 }
else if (
auto* typeCast = as<ast::TypeCast>(arg)) {
429 return isConstantArgument(typeCast->getValue());
430 }
else if (
auto* term = as<Term>(arg)) {
432 return all_of(term->getArguments(), isConstantArgument);
433 }
else if (isA<Constant>(arg)) {
436 fatal(
"unsupported argument type: %s",
typeid(arg).name());
443 void SemanticCheckerImpl::checkFact(
const Clause& fact) {
446 Atom* head = fact.getHead();
447 if (head ==
nullptr) {
452 if (
rel ==
nullptr) {
457 for (
auto* arg : head->getArguments()) {
458 if (!isConstantArgument(arg)) {
459 report.
addError(
"Argument in fact is not constant", arg->getSrcLoc());
464 void SemanticCheckerImpl::checkClause(
const Clause& clause) {
466 checkAtom(*clause.getHead());
469 for (
auto* unnamed : getUnnamedVariables(*clause.getHead())) {
470 report.
addError(
"Underscore in head of rule", unnamed->getSrcLoc());
474 for (
Literal* lit : clause.getBodyLiterals()) {
486 std::map<std::string, int> var_count;
487 std::map<std::string, const ast::Variable*> var_pos;
489 var_count[var.getName()]++;
490 var_pos[var.getName()] = &var;
492 for (
const auto& cur : var_count) {
493 int numAppearances = cur.second;
494 const auto& varName = cur.first;
495 const auto& varLocation = var_pos[varName]->getSrcLoc();
496 if (varName[0] ==
'_') {
497 assert(varName.size() > 1 &&
"named variable should not be a single underscore");
498 if (numAppearances > 1) {
499 report.
addWarning(
"Variable " + varName +
" marked as singleton but occurs more than once",
506 if (clause.getExecutionPlan() !=
nullptr) {
507 auto numAtoms = getBodyLiterals<Atom>(clause).size();
508 for (
const auto& cur : clause.getExecutionPlan()->getOrders()) {
509 bool isComplete =
true;
510 auto order = cur.second->getOrder();
511 for (
unsigned i = 1;
i <= order.size();
i++) {
517 if (order.size() != numAtoms || !isComplete) {
518 report.
addError(
"Invalid execution order in plan", cur.second->getSrcLoc());
524 if (recursiveClauses.
recursive(&clause)) {
526 report.
addError(
"Auto-increment functor in a recursive rule", ctr.getSrcLoc());
531 void SemanticCheckerImpl::checkComplexRule(std::set<const Clause*> multiRule) {
532 std::map<std::string, int> var_count;
533 std::map<std::string, const ast::Variable*> var_pos;
539 for (
auto literal : (*multiRule.begin())->getBodyLiterals()) {
541 var_count[var.getName()]++;
542 var_pos[var.getName()] = &var;
547 for (
auto clause : multiRule) {
549 var_count[var.getName()]++;
550 var_pos[var.getName()] = &var;
555 for (
const auto& cur : var_count) {
556 int numAppearances = cur.second;
557 const auto& varName = cur.first;
558 const auto& varLocation = var_pos[varName]->getSrcLoc();
559 if (varName[0] !=
'_' && numAppearances == 1) {
560 report.
addWarning(
"Variable " + varName +
" only occurs once", varLocation);
565 void SemanticCheckerImpl::checkRelationDeclaration(
const Relation&
relation) {
566 const auto& attributes =
relation.getAttributes();
567 assert(attributes.size() ==
relation.
getArity() &&
"mismatching attribute size and arity");
570 Attribute* attr = attributes[
i];
571 auto&& typeName = attr->getTypeName();
572 auto* existingType =
getIf(program.getTypes(),
573 [&](
const ast::Type*
type) { return type->getQualifiedName() == typeName; });
581 for (
size_t j = 0;
j <
i;
j++) {
582 if (attr->getName() == attributes[
j]->getName()) {
583 report.
addError(
tfm::format(
"Doubly defined attribute name %s", *attr), attr->getSrcLoc());
589 void SemanticCheckerImpl::checkRelation(
const Relation&
relation) {
592 const auto& attributes =
relation.getAttributes();
593 assert(attributes.size() == 2 &&
"mismatching attribute size and arity");
594 if (attributes[0]->getTypeName() != attributes[1]->getTypeName()) {
601 "Equivalence relation " +
toString(
relation.getQualifiedName()) +
" is not binary",
617 void SemanticCheckerImpl::checkIO() {
618 auto checkIO = [&](
const Directive* directive) {
619 auto* r =
getRelation(program, directive->getQualifiedName());
622 "Undefined relation " +
toString(directive->getQualifiedName()), directive->getSrcLoc());
625 for (
const auto* directive : program.getDirectives()) {
642 TranslationUnit& tu,
const Clause& clause,
const Aggregator& aggregate) {
643 std::vector<SrcLocation> invalidWitnessLocations;
646 return invalidWitnessLocations;
649 auto aggregateSubclause = mk<Clause>();
650 aggregateSubclause->setHead(mk<Atom>(
"*"));
651 for (
const Literal* lit : aggregate.getBodyLiterals()) {
654 struct InnerAggregateMasker :
public NodeMapper {
655 mutable int numReplaced = 0;
657 if (isA<Aggregator>(node.get())) {
658 std::string newVariableName =
"+aggr_var_" +
toString(numReplaced++);
659 return mk<Variable>(newVariableName);
665 InnerAggregateMasker update;
666 aggregateSubclause->apply(update);
676 for (
const auto& witness : witnesses) {
678 if (var.getName() == witness) {
679 invalidWitnessLocations.push_back(var.getSrcLoc());
683 return invalidWitnessLocations;
686 void SemanticCheckerImpl::checkWitnessProblem() {
694 "Witness problem: argument grounded by an aggregator's inner scope is used "
696 "outer scope in a count/sum/mean aggregate",
707 std::vector<QualifiedName>
findInlineCycle(
const PrecedenceGraphAnalysis& precedenceGraph,
708 std::map<const Relation*, const Relation*>& origins,
const Relation* current,
RelationSet& unvisited,
710 std::vector<QualifiedName> result;
712 if (current ==
nullptr) {
715 if (unvisited.empty()) {
721 current = *unvisited.begin();
722 origins[current] =
nullptr;
725 unvisited.erase(current);
726 visiting.insert(current);
729 std::vector<QualifiedName> subresult =
730 findInlineCycle(precedenceGraph, origins, current, unvisited, visiting, visited);
732 if (subresult.empty()) {
734 return findInlineCycle(precedenceGraph, origins,
nullptr, unvisited, visiting, visited);
742 const RelationSet& successors = precedenceGraph.graph().successors(current);
743 for (
const Relation* successor : successors) {
746 if (visited.find(successor) != visited.end()) {
751 if (visiting.find(successor) != visiting.end()) {
754 while (current !=
nullptr) {
755 result.push_back(current->getQualifiedName());
756 current = origins[current];
762 origins[successor] = current;
765 unvisited.erase(successor);
766 visiting.insert(successor);
769 std::vector<QualifiedName> subgraphCycle =
770 findInlineCycle(precedenceGraph, origins, successor, unvisited, visiting, visited);
772 if (!subgraphCycle.empty()) {
774 return subgraphCycle;
780 visiting.erase(current);
781 visited.insert(current);
785 void SemanticCheckerImpl::checkInlining() {
790 for (
const auto&
relation : program.getRelations()) {
795 "IO relation " +
toString(
relation->getQualifiedName()) +
" cannot be inlined",
812 unvisited.insert(
rel);
816 std::map<const Relation*, const Relation*> origins;
818 std::vector<QualifiedName> result =
819 findInlineCycle(precedenceGraph, origins,
nullptr, unvisited, visiting, visited);
822 if (!result.empty()) {
826 std::stringstream cycle;
827 cycle <<
"{" << cycleOrigin->getQualifiedName();
830 for (
int i = result.size() - 2;
i >= 0;
i--) {
831 cycle <<
", " << result[
i];
837 "Cannot inline cyclically dependent relations " + cycle.str(), cycleOrigin->getSrcLoc());
845 Relation* associatedRelation =
getRelation(program, atom.getQualifiedName());
846 if (associatedRelation !=
nullptr && isInline(associatedRelation)) {
848 if (isA<Counter>(&arg)) {
850 "Cannot inline literal containing a counter argument '$'", arg.getSrcLoc());
857 for (
const Relation*
rel : inlinedRelations) {
860 if (isA<Counter>(&arg)) {
862 "Cannot inline clause containing a counter argument '$'", arg.getSrcLoc());
875 for (
const Relation*
rel : inlinedRelations) {
876 bool foundNonNegatable =
false;
879 std::set<std::string> headVariables;
881 [&](
const ast::Variable& var) { headVariables.insert(var.getName()); });
884 std::set<std::string> bodyVariables;
886 [&](
const ast::Variable& var) { bodyVariables.insert(var.getName()); });
890 for (
const std::string& var : bodyVariables) {
891 if (headVariables.find(var) == headVariables.end()) {
892 nonNegatableRelations.insert(
rel);
893 foundNonNegatable =
true;
898 if (foundNonNegatable) {
906 Relation* associatedRelation =
getRelation(program, neg.getAtom()->getQualifiedName());
907 if (associatedRelation !=
nullptr &&
908 nonNegatableRelations.find(associatedRelation) != nonNegatableRelations.end()) {
910 "Cannot inline negated relation which may introduce new variables", neg.getSrcLoc());
928 const Relation*
rel =
getRelation(program, subatom.getQualifiedName());
929 if (
rel !=
nullptr && isInline(
rel)) {
930 report.
addError(
"Cannot inline relations that appear in aggregator", subatom.getSrcLoc());
951 std::function<std::pair<bool, SrcLocation>(
const Node*)> checkInvalidUnderscore = [&](
const Node* node) {
952 if (isA<UnnamedVariable>(node)) {
954 return std::make_pair(
true, node->getSrcLoc());
955 }
else if (isA<Aggregator>(node)) {
957 return std::make_pair(
false, node->getSrcLoc());
961 for (
const Node* child : node->getChildNodes()) {
962 std::pair<bool, SrcLocation> childStatus = checkInvalidUnderscore(child);
963 if (childStatus.first) {
969 return std::make_pair(
false, node->getSrcLoc());
974 const Atom* associatedAtom = negation.getAtom();
975 const Relation* associatedRelation =
getRelation(program, associatedAtom->getQualifiedName());
976 if (associatedRelation !=
nullptr && isInline(associatedRelation)) {
977 std::pair<bool, SrcLocation> atomStatus = checkInvalidUnderscore(associatedAtom);
978 if (atomStatus.first) {
980 "Cannot inline negated atom containing an unnamed variable unless the variable is "
981 "within an aggregator",
989 void SemanticCheckerImpl::checkNamespaces() {
990 std::map<std::string, SrcLocation> names;
993 for (
const auto&
type : program.getTypes()) {
994 const std::string name =
toString(
type->getQualifiedName());
995 if (names.count(name) != 0u) {
996 report.
addError(
"Name clash on type " + name,
type->getSrcLoc());
998 names[name] =
type->getSrcLoc();
1002 for (
const auto&
rel : program.getRelations()) {
1003 const std::string name =
toString(
rel->getQualifiedName());
1004 if (names.count(name) != 0u) {
1005 report.
addError(
"Name clash on relation " + name,
rel->getSrcLoc());
1007 names[name] =
rel->getSrcLoc();