63 const Clause* clause,
const std::map<const Argument*, TypeSet> argumentTypes) {
65 struct TypeAnnotator :
public NodeMapper {
66 const std::map<const Argument*, TypeSet>&
types;
72 std::stringstream newVarName;
73 newVarName << var->getName() <<
"∈" <<
types.find(var)->second;
74 return mk<ast::Variable>(newVarName.str());
76 std::stringstream newVarName;
78 <<
"∈" <<
types.find(var)->second;
79 return mk<ast::Variable>(newVarName.str());
99 std::map<const Argument*, const Argument*> memoryMap;
101 std::vector<const Argument*> originalAddresses;
102 visitDepthFirst(*clause, [&](
const Argument& arg) { originalAddresses.push_back(&arg); });
104 std::vector<const Argument*> cloneAddresses;
105 visitDepthFirst(*annotatedClause, [&](
const Argument& arg) { cloneAddresses.push_back(&arg); });
107 assert(cloneAddresses.size() == originalAddresses.size());
109 for (
size_t i = 0;
i < originalAddresses.size();
i++) {
110 memoryMap[originalAddresses[
i]] = cloneAddresses[
i];
114 std::map<const Argument*, TypeSet> cloneArgumentTypes;
116 cloneArgumentTypes[memoryMap[pair.first]] = pair.second;
120 TypeAnnotator annotator(cloneArgumentTypes);
121 annotatedClause->apply(annotator);
122 return annotatedClause;
126 const TranslationUnit& tu,
const Clause& clause, std::ostream* logs) {
127 return TypeConstraintsAnalysis(tu).analyse(clause, logs);
131 os <<
"-- Analysis logs --" << std::endl;
133 os <<
"-- Result --" << std::endl;
135 os << *cur << std::endl;
141 if (
auto* intrinsic = as<IntrinsicFunctor>(functor)) {
143 }
else if (
const auto* udf = as<UserDefinedFunctor>(functor)) {
146 fatal(
"Missing functor type.");
151 if (
auto* intrinsic = as<IntrinsicFunctor>(functor)) {
153 return info->params.at(info->variadic ? 0 : idx);
154 }
else if (
auto* udf = as<UserDefinedFunctor>(functor)) {
157 fatal(
"Missing functor type.");
173 if (isA<UserDefinedFunctor>(functor)) {
175 }
else if (
auto* intrinsic = as<IntrinsicFunctor>(functor)) {
177 assert(!candidates.empty() &&
"at least one op should match");
178 return candidates[0].get().multipleResults;
180 fatal(
"Missing functor type.");
184 std::set<TypeAttribute> typeAttributes;
190 return typeAttributes;
195 return typeAttributes;
207 return typeAttributes;
217 auto argTypes =
map(inf.getArguments(), [&](
const Argument* arg) { return getTypeAttributes(arg); });
220 if (!candidate.variadic && argTypes.size() != candidate.params.size()) {
225 for (
size_t i = 0;
i < argTypes.size(); ++
i) {
226 const auto& expectedType = candidate.params[candidate.variadic ? 0 :
i];
227 if (!
contains(argTypes[
i], expectedType)) {
233 return contains(returnTypes, candidate.result);
235 auto candidates =
filter(functorInfos, isValidOverload);
238 auto comparator = [&](
const IntrinsicFunctorInfo& a,
const IntrinsicFunctorInfo&
b) {
239 if (a.result !=
b.result)
return a.result <
b.result;
240 if (a.variadic !=
b.variadic)
return a.variadic <
b.variadic;
241 return std::lexicographical_compare(
242 a.params.begin(), a.params.end(),
b.params.begin(),
b.params.end());
244 std::sort(candidates.begin(), candidates.end(),
comparator);
250 if (
auto* inf = as<IntrinsicFunctor>(argument)) {
252 }
else if (
auto* udf = as<UserDefinedFunctor>(argument)) {
254 }
else if (
auto* nc = as<NumericConstant>(argument)) {
256 }
else if (
auto* agg = as<Aggregator>(argument)) {
283 bool changed =
false;
284 const auto& program = translationUnit.
getProgram();
287 if (candidates.empty()) {
297 const auto* curInfo = &candidates.front().get();
308 bool changed =
false;
309 const auto& program = translationUnit.getProgram();
322 setNumericConstantType(numericConstant, numericConstant.getFixedType().value());
350 bool changed =
false;
351 const auto& program = translationUnit.getProgram();
353 auto setAggregatorType = [&](
const Aggregator& agg,
TypeAttribute attr) {
364 auto* targetExpression = agg.getTargetExpression();
365 if (isFloat(targetExpression)) {
366 setAggregatorType(agg, TypeAttribute::Float);
375 "non-overloaded aggr types should always be the base operator");
387 bool changed =
false;
388 const auto& program = translationUnit.getProgram();
390 auto setConstraintType = [&](
const BinaryConstraint& bc,
TypeAttribute attr) {
402 auto* leftArg = binaryConstraint.getLHS();
403 auto* rightArg = binaryConstraint.getRHS();
406 if (isFloat(leftArg) && isFloat(rightArg)) {
407 setConstraintType(binaryConstraint, TypeAttribute::Float);
418 "unexpected constraint type");
443 std::ostream* debugStream =
nullptr;
461 auto clauseArgumentTypes =
analyseTypes(translationUnit, *clause, debugStream);
462 argumentTypes.insert(clauseArgumentTypes.begin(), clauseArgumentTypes.end());
464 if (debugStream !=
nullptr) {